From 0549f5cafccae8ad2262c1247108e6d6a268d094 Mon Sep 17 00:00:00 2001 From: Acbox Date: Sun, 12 Apr 2026 17:25:10 +0800 Subject: [PATCH] feat(command): improve slash command UX Make slash commands easier to navigate in chat by splitting help into levels, compacting list output, and surfacing current selections for model, search, memory, and browser settings. Also route /status to the active conversation session and add an access inspector so users can understand their current command and ACL context. --- cmd/agent/main.go | 9 ++ cmd/memoh/serve.go | 9 ++ internal/channel/inbound/channel.go | 107 +++++++++++++- internal/channel/inbound/channel_test.go | 135 +++++++++++++++++ internal/command/access.go | 78 ++++++++++ internal/command/browser.go | 43 +++++- internal/command/commands.go | 1 + internal/command/email_cmd.go | 6 +- internal/command/formatter.go | 46 ++++-- internal/command/handler.go | 70 +++++++-- internal/command/handler_test.go | 177 ++++++++++++++++++++++- internal/command/heartbeat_cmd.go | 2 +- internal/command/interfaces.go | 26 +++- internal/command/mcp.go | 2 +- internal/command/memory.go | 45 +++++- internal/command/model.go | 160 ++++++++++++++++++-- internal/command/parser_test.go | 2 + internal/command/registry.go | 90 ++++++++++-- internal/command/search.go | 45 +++++- internal/command/settings.go | 8 + internal/command/status.go | 151 +++++++++++-------- internal/command/usage.go | 6 + 22 files changed, 1080 insertions(+), 138 deletions(-) create mode 100644 internal/command/access.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 0aa3eefd..8c898a14 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -604,6 +604,7 @@ func provideChannelRouter( emailOutboxService, heartbeatService, queries, + aclService, &commandSkillLoaderAdapter{handler: containerdHandler}, &commandContainerFSAdapter{manager: manager}, )) @@ -779,6 +780,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID, return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil } +func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) { + sess, err := a.svc.GetActiveForRoute(ctx, routeID) + if err != nil { + return inbound.SessionResult{}, err + } + return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil +} + func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) { sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType) if err != nil { diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 3a36cc41..75a608ac 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -492,6 +492,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc emailOutboxService, heartbeatService, queries, + aclService, &commandSkillLoaderAdapter{handler: containerdHandler}, &commandContainerFSAdapter{manager: manager}, )) @@ -877,6 +878,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID, return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil } +func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) { + sess, err := a.svc.GetActiveForRoute(ctx, routeID) + if err != nil { + return inbound.SessionResult{}, err + } + return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil +} + func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) { sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType) if err != nil { diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index 3c20b8b0..ced6c869 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -71,6 +71,7 @@ type ttsModelResolver interface { // SessionEnsurer resolves or creates an active session for a route. type SessionEnsurer interface { EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error) + GetActiveSession(ctx context.Context, routeID string) (SessionResult, error) // CreateNewSession always creates a fresh session and sets it as the // active session for the given route, replacing any previous one. // sessionType defaults to "chat" if empty. @@ -298,11 +299,23 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel if isStopCommand(cmdText) && isDirectedAtBot(msg) { return p.handleStopCommand(ctx, cfg, msg, sender, identity) } + if isStatusCommand(cmdText) && isDirectedAtBot(msg) { + return p.handleStatusCommand(ctx, cfg, msg, sender, identity) + } // Skip generic command handler for mode-prefix commands (/btw, /now, /next) // so they pass through to mode detection below. if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) { - reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText) + reply, err := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{ + BotID: strings.TrimSpace(identity.BotID), + ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID), + UserID: strings.TrimSpace(identity.UserID), + Text: cmdText, + ChannelType: msg.Channel.String(), + ConversationType: strings.TrimSpace(msg.Conversation.Type), + ConversationID: strings.TrimSpace(msg.Conversation.ID), + ThreadID: extractThreadID(msg), + }) if err != nil { reply = "Error: " + err.Error() } @@ -2489,6 +2502,18 @@ func isNewSessionCommand(cmdText string) bool { return parsed.Resource == "new" } +func isStatusCommand(cmdText string) bool { + extracted := command.ExtractCommandText(cmdText) + if extracted == "" { + return false + } + parsed, err := command.Parse(extracted) + if err != nil { + return false + } + return parsed.Resource == "status" +} + // resolveNewSessionType determines the session type for /new command. // /new chat → chat, /new discuss → discuss, /new (no arg) → default by context. // WebUI (local channel) always defaults to chat. @@ -2611,3 +2636,83 @@ func (p *ChannelInboundProcessor) handleNewSessionCommand( Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)}, }) } + +func (p *ChannelInboundProcessor) handleStatusCommand( + ctx context.Context, + cfg channel.ChannelConfig, + msg channel.InboundMessage, + sender channel.StreamReplySender, + identity InboundIdentity, +) error { + target := strings.TrimSpace(msg.ReplyTarget) + if target == "" { + return errors.New("reply target missing for /status command") + } + if p.routeResolver == nil { + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: "Error: route resolver not configured."}, + }) + } + if p.commandHandler == nil { + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: "Error: command handler not configured."}, + }) + } + + threadID := extractThreadID(msg) + routeMetadata := buildRouteMetadata(msg, identity) + p.enrichConversationAvatar(ctx, cfg, msg, routeMetadata) + resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{ + BotID: identity.BotID, + Platform: msg.Channel.String(), + ConversationID: msg.Conversation.ID, + ThreadID: threadID, + ConversationType: msg.Conversation.Type, + ChannelIdentityID: identity.UserID, + ChannelConfigID: identity.ChannelConfigID, + ReplyTarget: target, + Metadata: routeMetadata, + }) + if err != nil { + if p.logger != nil { + p.logger.Warn("resolve route for /status command failed", slog.Any("error", err)) + } + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: "Error: failed to resolve conversation route."}, + }) + } + + sessionID := "" + if p.sessionEnsurer != nil { + sess, sessErr := p.sessionEnsurer.GetActiveSession(ctx, resolved.RouteID) + if sessErr == nil { + sessionID = strings.TrimSpace(sess.ID) + } else if p.logger != nil { + p.logger.Debug("resolve active session for /status command failed", slog.Any("error", sessErr)) + } + } + + reply, execErr := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{ + BotID: strings.TrimSpace(identity.BotID), + ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID), + UserID: strings.TrimSpace(identity.UserID), + Text: rawTextForCommand(msg, strings.TrimSpace(msg.Message.PlainText())), + ChannelType: msg.Channel.String(), + ConversationType: strings.TrimSpace(msg.Conversation.Type), + ConversationID: strings.TrimSpace(msg.Conversation.ID), + ThreadID: threadID, + RouteID: strings.TrimSpace(resolved.RouteID), + SessionID: sessionID, + }) + if execErr != nil { + reply = "Error: " + execErr.Error() + } + + return sender.Send(ctx, channel.OutboundMessage{ + Target: target, + Message: channel.Message{Text: reply}, + }) +} diff --git a/internal/channel/inbound/channel_test.go b/internal/channel/inbound/channel_test.go index fe00e9fc..9cbb1ae5 100644 --- a/internal/channel/inbound/channel_test.go +++ b/internal/channel/inbound/channel_test.go @@ -13,11 +13,15 @@ import ( "strings" "testing" + "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/acl" "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/command" "github.com/memohai/memoh/internal/conversation" + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/media" messagepkg "github.com/memohai/memoh/internal/message" "github.com/memohai/memoh/internal/schedule" @@ -185,6 +189,71 @@ type fakeChatACL struct { lastReq acl.EvaluateRequest } +type fakeSessionEnsurer struct { + activeSession SessionResult + activeErr error + lastRouteID string +} + +func (f *fakeSessionEnsurer) EnsureActiveSession(_ context.Context, _, routeID, _ string) (SessionResult, error) { + f.lastRouteID = routeID + if f.activeErr != nil { + return SessionResult{}, f.activeErr + } + return f.activeSession, nil +} + +func (f *fakeSessionEnsurer) GetActiveSession(_ context.Context, routeID string) (SessionResult, error) { + f.lastRouteID = routeID + if f.activeErr != nil { + return SessionResult{}, f.activeErr + } + return f.activeSession, nil +} + +func (f *fakeSessionEnsurer) CreateNewSession(_ context.Context, _, routeID, _, _ string) (SessionResult, error) { + f.lastRouteID = routeID + if f.activeErr != nil { + return SessionResult{}, f.activeErr + } + return f.activeSession, nil +} + +type fakeCommandQueries struct { + messageCount int64 + usage int64 + cacheRow dbsqlc.GetSessionCacheStatsRow + skills []string +} + +func (*fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) { + return pgtype.UUID{}, errors.New("unexpected latest session lookup") +} + +func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) { + return f.messageCount, nil +} + +func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) { + return f.usage, nil +} + +func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) { + return f.cacheRow, nil +} + +func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) { + return f.skills, nil +} + +func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) { + return nil, nil +} + +func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) { + return nil, nil +} + func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) { f.calls++ f.lastReq = req @@ -548,6 +617,72 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { } } +func TestChannelInboundProcessorStatusUsesRouteSession(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-status"}} + policySvc := &fakePolicyService{} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-status", RouteID: "route-status"}} + gateway := &fakeChatGateway{} + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0) + processor.SetSessionEnsurer(&fakeSessionEnsurer{ + activeSession: SessionResult{ID: "11111111-1111-1111-1111-111111111111", Type: "chat"}, + }) + processor.SetCommandHandler(command.NewHandler( + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + &fakeCommandQueries{ + messageCount: 9, + usage: 512, + cacheRow: dbsqlc.GetSessionCacheStatsRow{ + CacheReadTokens: 64, + CacheWriteTokens: 32, + TotalInputTokens: 512, + }, + skills: []string{"search"}, + }, + nil, + nil, + nil, + )) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("discord")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("discord"), + Message: channel.Message{Text: "/status"}, + ReplyTarget: "discord:status", + Sender: channel.Identity{SubjectID: "user-1"}, + Conversation: channel.Conversation{ + ID: "conv-status", + Type: channel.ConversationTypePrivate, + }, + } + + if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(sender.sent) != 1 { + t.Fatalf("expected one status reply, got %d", len(sender.sent)) + } + if !strings.Contains(sender.sent[0].Message.Text, "- Scope: current conversation") { + t.Fatalf("expected current conversation scope, got %q", sender.sent[0].Message.Text) + } + if !strings.Contains(sender.sent[0].Message.Text, "- Session ID: 11111111-1111-1111-1111-111111111111") { + t.Fatalf("expected active route session in reply, got %q", sender.sent[0].Message.Text) + } +} + func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) { t.Parallel() diff --git a/internal/command/access.go b/internal/command/access.go new file mode 100644 index 00000000..5dbd764c --- /dev/null +++ b/internal/command/access.go @@ -0,0 +1,78 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/memohai/memoh/internal/acl" +) + +func (h *Handler) buildAccessGroup() *CommandGroup { + g := newCommandGroup("access", "Inspect identity and permission context") + g.DefaultAction = "show" + g.Register(SubCommand{ + Name: "show", + Usage: "show - Show current identity, write access, and chat ACL context", + Handler: func(cc CommandContext) (string, error) { + writeAccess := "no" + if cc.Role == "owner" { + writeAccess = "yes" + } + + pairs := []kv{ + {"Channel Identity", fallbackValue(cc.ChannelIdentityID)}, + {"Linked User", fallbackValue(cc.UserID)}, + {"Bot Role", fallbackValue(cc.Role)}, + {"Write Commands", writeAccess}, + {"Channel", fallbackValue(cc.ChannelType)}, + {"Conversation Type", fallbackValue(cc.ConversationType)}, + {"Conversation ID", fallbackValue(cc.ConversationID)}, + {"Thread ID", fallbackValue(cc.ThreadID)}, + } + if strings.TrimSpace(cc.RouteID) != "" { + pairs = append(pairs, kv{"Route ID", cc.RouteID}) + } + if strings.TrimSpace(cc.SessionID) != "" { + pairs = append(pairs, kv{"Session ID", cc.SessionID}) + } + + aclStatus := "unavailable" + if h.aclEvaluator != nil && strings.TrimSpace(cc.ChannelType) != "" { + allowed, err := h.aclEvaluator.Evaluate(cc.Ctx, acl.EvaluateRequest{ + BotID: cc.BotID, + ChannelIdentityID: cc.ChannelIdentityID, + ChannelType: cc.ChannelType, + SourceScope: acl.SourceScope{ + ConversationType: cc.ConversationType, + ConversationID: cc.ConversationID, + ThreadID: cc.ThreadID, + }, + }) + switch { + case err != nil: + aclStatus = "error: " + err.Error() + case allowed: + aclStatus = "allow" + default: + aclStatus = "deny" + } + } + pairs = append(pairs, kv{"Chat ACL", aclStatus}) + + return formatKV(pairs), nil + }, + }) + return g +} + +func fallbackValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "(none)" + } + return value +} + +func formatChangedValue(label, before, after string) string { + return fmt.Sprintf("%s: %s -> %s", label, fallbackValue(before), fallbackValue(after)) +} diff --git a/internal/command/browser.go b/internal/command/browser.go index d5de3a4d..bc4b77ac 100644 --- a/internal/command/browser.go +++ b/internal/command/browser.go @@ -13,6 +13,9 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { Name: "list", Usage: "list - List all browser contexts", Handler: func(cc CommandContext) (string, error) { + if h.browserCtxService == nil { + return "Browser context service is not available.", nil + } items, err := h.browserCtxService.List(cc.Ctx) if err != nil { return "", err @@ -20,13 +23,37 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { if len(items) == 0 { return "No browser contexts found.", nil } - records := make([][]kv, 0, len(items)) + settingsResp, _ := h.getBotSettings(cc) + currentRecords := make([][]kv, 0, 1) + otherRecords := make([][]kv, 0, len(items)) for _, item := range items { - records = append(records, []kv{ - {"Name", item.Name}, - }) + label := item.Name + record := []kv{{"Name", label}} + if item.ID == settingsResp.BrowserContextID { + label += " [current]" + record[0].value = label + currentRecords = append(currentRecords, record) + continue + } + otherRecords = append(otherRecords, record) } - return formatItems(records), nil + currentRecords = append(currentRecords, otherRecords...) + records := currentRecords + return formatLimitedItems(records, defaultListLimit, "Use /browser current to inspect the active context."), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show the current browser context", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{{"Browser Context", h.resolveBrowserContextName(cc, settingsResp.BrowserContextID)}}), nil }, }) g.Register(SubCommand{ @@ -37,7 +64,11 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { if len(cc.Args) < 1 { return "Usage: /browser set ", nil } + if h.settingsService == nil { + return "Settings service is not available.", nil + } name := cc.Args[0] + before, _ := h.getBotSettings(cc) items, err := h.browserCtxService.List(cc.Ctx) if err != nil { return "", err @@ -50,7 +81,7 @@ func (h *Handler) buildBrowserGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Browser context set to %q.", item.Name), nil + return formatChangedValue("Browser context", h.resolveBrowserContextName(cc, before.BrowserContextID), item.Name), nil } } return fmt.Sprintf("Browser context %q not found.", name), nil diff --git a/internal/command/commands.go b/internal/command/commands.go index b40b3513..919f968a 100644 --- a/internal/command/commands.go +++ b/internal/command/commands.go @@ -16,5 +16,6 @@ func (h *Handler) buildRegistry() *Registry { r.RegisterGroup(h.buildSkillGroup()) r.RegisterGroup(h.buildFSGroup()) r.RegisterGroup(h.buildStatusGroup()) + r.RegisterGroup(h.buildAccessGroup()) return r } diff --git a/internal/command/email_cmd.go b/internal/command/email_cmd.go index 951f70c2..43fddbba 100644 --- a/internal/command/email_cmd.go +++ b/internal/command/email_cmd.go @@ -24,7 +24,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup { {"Provider", item.Provider}, }) } - return formatItems(records), nil + return formatLimitedItems(records, defaultListLimit, "Use /email bindings to inspect bot bindings."), nil }, }) g.Register(SubCommand{ @@ -46,7 +46,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup { {"Permissions", perms}, }) } - return formatItems(records), nil + return formatLimitedItems(records, defaultListLimit, "Use /email outbox to inspect recent sends."), nil }, }) g.Register(SubCommand{ @@ -70,7 +70,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup { {"Sent", item.SentAt.Format("01-02 15:04")}, }) } - return formatItems(records), nil + return formatLimitedItems(records, 10, "Use the Web UI for older outbox entries."), nil }, }) return g diff --git a/internal/command/formatter.go b/internal/command/formatter.go index 0b03996f..1b626ce9 100644 --- a/internal/command/formatter.go +++ b/internal/command/formatter.go @@ -6,19 +6,10 @@ import ( "unicode/utf8" ) +const defaultListLimit = 12 + // formatItems renders a list of records as a Markdown-style list. -// Each record is a slice of kv pairs; the first pair's value is used as the -// bullet title, and subsequent pairs are indented beneath it. -// -// Example output: -// -// - mybot -// Description: A helpful assistant -// ID: abc123 -// -// - another -// Description: Something else -// ID: def456 +// Each record is rendered on a single line so long lists stay readable in IM. func formatItems(items [][]kv) string { if len(items) == 0 { return "" @@ -31,14 +22,41 @@ func formatItems(items [][]kv) string { if i > 0 { b.WriteByte('\n') } - fmt.Fprintf(&b, "- %s\n", record[0].value) + fmt.Fprintf(&b, "- %s", record[0].value) + extras := make([]string, 0, len(record)-1) for _, pair := range record[1:] { - fmt.Fprintf(&b, " %s: %s\n", pair.key, pair.value) + if strings.TrimSpace(pair.value) == "" { + continue + } + extras = append(extras, fmt.Sprintf("%s: %s", pair.key, pair.value)) + } + if len(extras) > 0 { + fmt.Fprintf(&b, " | %s", strings.Join(extras, " | ")) } } return b.String() } +func formatLimitedItems(items [][]kv, limit int, hint string) string { + if len(items) == 0 { + return "" + } + if limit <= 0 { + limit = defaultListLimit + } + total := len(items) + if total <= limit { + return formatItems(items) + } + shown := items[:limit] + result := formatItems(shown) + suffix := fmt.Sprintf("Showing %d of %d items.", len(shown), total) + if strings.TrimSpace(hint) != "" { + suffix += " " + strings.TrimSpace(hint) + } + return result + "\n\n" + suffix +} + // formatKV renders key-value pairs as a simple Markdown list. // // Example output: diff --git a/internal/command/handler.go b/internal/command/handler.go index ecc510be..810d2ab1 100644 --- a/internal/command/handler.go +++ b/internal/command/handler.go @@ -4,10 +4,10 @@ import ( "context" "fmt" "log/slog" + "strings" "github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/browsercontexts" - dbsqlc "github.com/memohai/memoh/internal/db/sqlc" emailpkg "github.com/memohai/memoh/internal/email" "github.com/memohai/memoh/internal/heartbeat" "github.com/memohai/memoh/internal/mcp" @@ -56,13 +56,28 @@ type Handler struct { emailService *emailpkg.Service emailOutboxService *emailpkg.OutboxService heartbeatService *heartbeat.Service - queries *dbsqlc.Queries + queries CommandQueries + aclEvaluator AccessEvaluator skillLoader SkillLoader containerFS ContainerFS logger *slog.Logger } +// ExecuteInput carries the caller identity and channel context for command execution. +type ExecuteInput struct { + BotID string + ChannelIdentityID string + UserID string + Text string + ChannelType string + ConversationType string + ConversationID string + ThreadID string + RouteID string + SessionID string +} + // NewHandler creates a Handler with all required services. func NewHandler( log *slog.Logger, @@ -78,7 +93,8 @@ func NewHandler( emailService *emailpkg.Service, emailOutboxService *emailpkg.OutboxService, heartbeatService *heartbeat.Service, - queries *dbsqlc.Queries, + queries CommandQueries, + aclEvaluator AccessEvaluator, skillLoader SkillLoader, containerFS ContainerFS, ) *Handler { @@ -99,6 +115,7 @@ func NewHandler( emailOutboxService: emailOutboxService, heartbeatService: heartbeatService, queries: queries, + aclEvaluator: aclEvaluator, skillLoader: skillLoader, containerFS: containerFS, logger: log.With(slog.String("component", "command")), @@ -140,7 +157,16 @@ func (h *Handler) IsCommand(text string) bool { // Execute parses and runs a slash command, returning the text reply. func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) { - cmdText := ExtractCommandText(text) + return h.ExecuteWithInput(ctx, ExecuteInput{ + BotID: botID, + ChannelIdentityID: channelIdentityID, + Text: text, + }) +} + +// ExecuteWithInput parses and runs a slash command with channel/session context. +func (h *Handler) ExecuteWithInput(ctx context.Context, input ExecuteInput) (string, error) { + cmdText := ExtractCommandText(input.Text) if cmdText == "" { return h.registry.GlobalHelp(), nil } @@ -151,12 +177,16 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st // Resolve the user's role in this bot. role := "" - if h.roleResolver != nil && channelIdentityID != "" { - r, err := h.roleResolver.GetMemberRole(ctx, botID, channelIdentityID) + roleIdentityID := input.ChannelIdentityID + if strings.TrimSpace(input.UserID) != "" { + roleIdentityID = strings.TrimSpace(input.UserID) + } + if h.roleResolver != nil && roleIdentityID != "" { + r, err := h.roleResolver.GetMemberRole(ctx, input.BotID, roleIdentityID) if err != nil { h.logger.Warn("failed to resolve member role", - slog.String("bot_id", botID), - slog.String("channel_identity_id", channelIdentityID), + slog.String("bot_id", input.BotID), + slog.String("role_identity_id", roleIdentityID), slog.Any("error", err), ) } else { @@ -165,15 +195,29 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st } cc := CommandContext{ - Ctx: ctx, - BotID: botID, - Role: role, - Args: parsed.Args, + Ctx: ctx, + BotID: input.BotID, + Role: role, + Args: parsed.Args, + ChannelIdentityID: strings.TrimSpace(input.ChannelIdentityID), + UserID: strings.TrimSpace(input.UserID), + ChannelType: strings.TrimSpace(input.ChannelType), + ConversationType: strings.TrimSpace(input.ConversationType), + ConversationID: strings.TrimSpace(input.ConversationID), + ThreadID: strings.TrimSpace(input.ThreadID), + RouteID: strings.TrimSpace(input.RouteID), + SessionID: strings.TrimSpace(input.SessionID), } // /help if parsed.Resource == "help" { - return h.registry.GlobalHelp(), nil + if parsed.Action == "" { + return h.registry.GlobalHelp(), nil + } + if len(parsed.Args) == 0 { + return h.registry.GroupHelp(parsed.Action), nil + } + return h.registry.ActionHelp(parsed.Action, parsed.Args[0]), nil } // Top-level commands (e.g. /new) are handled by the channel inbound diff --git a/internal/command/handler_test.go b/internal/command/handler_test.go index f8205c4f..8cb79c73 100644 --- a/internal/command/handler_test.go +++ b/internal/command/handler_test.go @@ -5,6 +5,10 @@ import ( "strings" "testing" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/schedule" "github.com/memohai/memoh/internal/settings" @@ -25,9 +29,58 @@ type fakeScheduleService struct { items []schedule.Schedule } +type fakeCommandQueries struct { + latestSessionID pgtype.UUID + latestSessionErr error + messageCount int64 + latestUsage int64 + latestUsageErr error + cacheRow dbsqlc.GetSessionCacheStatsRow + cacheErr error + skills []string +} + +func (f *fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) { + return f.latestSessionID, f.latestSessionErr +} + +func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) { + return f.messageCount, nil +} + +func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) { + if f.latestUsageErr != nil { + return 0, f.latestUsageErr + } + return f.latestUsage, nil +} + +func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) { + if f.cacheErr != nil { + return dbsqlc.GetSessionCacheStatsRow{}, f.cacheErr + } + return f.cacheRow, nil +} + +func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) { + return f.skills, nil +} + +func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) { + return nil, nil +} + +func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) { + return nil, nil +} + // newTestHandler creates a Handler with nil services for use in tests. func newTestHandler(roleResolver MemberRoleResolver) *Handler { - return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) +} + +func newTestHandlerWithQueries(roleResolver MemberRoleResolver, queries CommandQueries) *Handler { + return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, queries, nil, nil, nil) } // --- tests --- @@ -73,6 +126,42 @@ func TestExecute_Help(t *testing.T) { if !strings.Contains(result, "Available commands") { t.Errorf("expected help text, got: %s", result) } + if strings.Contains(result, "set-heartbeat") { + t.Errorf("top-level help should not expand nested actions, got: %s", result) + } + if !strings.Contains(result, "- /model - Manage bot models") { + t.Errorf("expected top-level model entry, got: %s", result) + } +} + +func TestExecute_HelpGroup(t *testing.T) { + t.Parallel() + h := newTestHandler(&fakeRoleResolver{role: "owner"}) + result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "/model - Manage bot models") { + t.Errorf("expected group help, got: %s", result) + } + if !strings.Contains(result, "- set - Set the chat model [owner]") { + t.Errorf("expected compact action summary, got: %s", result) + } +} + +func TestExecute_HelpAction(t *testing.T) { + t.Parallel() + h := newTestHandler(&fakeRoleResolver{role: "owner"}) + result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model set") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "Usage: /model set | ") { + t.Errorf("expected action usage, got: %s", result) + } + if !strings.Contains(result, "Access: owner only") { + t.Errorf("expected owner hint, got: %s", result) + } } func TestExecute_UnknownCommand(t *testing.T) { @@ -207,8 +296,8 @@ func TestFormatItems(t *testing.T) { if !strings.Contains(result, "- foo") { t.Errorf("expected '- foo' bullet, got: %s", result) } - if !strings.Contains(result, " Type: bar") { - t.Errorf("expected indented 'Type: bar', got: %s", result) + if !strings.Contains(result, "- foo | Type: bar") { + t.Errorf("expected compact line entry, got: %s", result) } if !strings.Contains(result, "- longname") { t.Errorf("expected '- longname' bullet, got: %s", result) @@ -255,7 +344,7 @@ func TestGlobalHelp_AllGroups(t *testing.T) { for _, group := range []string{ "schedule", "mcp", "settings", "model", "memory", "search", "browser", "usage", - "email", "heartbeat", "skill", "fs", + "email", "heartbeat", "skill", "fs", "access", } { if !strings.Contains(help, "/"+group) { t.Errorf("missing /%s in global help", group) @@ -263,6 +352,86 @@ func TestGlobalHelp_AllGroups(t *testing.T) { } } +func TestExecuteWithInput_Access(t *testing.T) { + t.Parallel() + h := newTestHandler(&fakeRoleResolver{role: "owner"}) + result, err := h.ExecuteWithInput(context.Background(), ExecuteInput{ + BotID: "bot-1", + ChannelIdentityID: "channel-id-1", + UserID: "user-id-1", + Text: "/access", + ChannelType: "discord", + ConversationType: "thread", + ConversationID: "conv-1", + ThreadID: "thread-1", + RouteID: "route-1", + SessionID: "session-1", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "- Channel Identity: channel-id-1") { + t.Errorf("expected channel identity in access output, got: %s", result) + } + if !strings.Contains(result, "- Write Commands: yes") { + t.Errorf("expected write access in access output, got: %s", result) + } +} + +func TestExecute_StatusLatest(t *testing.T) { + t.Parallel() + sessionUUID := pgtype.UUID{} + copy(sessionUUID.Bytes[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) + sessionUUID.Valid = true + h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{ + latestSessionID: sessionUUID, + messageCount: 42, + latestUsage: 1200, + cacheRow: dbsqlc.GetSessionCacheStatsRow{ + CacheReadTokens: 300, + CacheWriteTokens: 150, + TotalInputTokens: 1200, + }, + skills: []string{"search", "browser"}, + }) + result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "- Scope: latest bot session") { + t.Errorf("expected latest scope, got: %s", result) + } + if !strings.Contains(result, "- Messages: 42") { + t.Errorf("expected message count, got: %s", result) + } +} + +func TestExecute_StatusLatestNoRows(t *testing.T) { + t.Parallel() + h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{ + latestSessionErr: pgx.ErrNoRows, + }) + result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "No session found for this bot.") { + t.Errorf("expected no session message, got: %s", result) + } +} + +func TestExecute_StatusShowWithoutSession(t *testing.T) { + t.Parallel() + h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{}) + result, err := h.Execute(context.Background(), "bot-1", "user-1", "/status") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(result, "No active session found for this conversation.") { + t.Errorf("expected route-aware no session message, got: %s", result) + } +} + // Verify write commands are tagged with [owner] in usage. func TestUsage_OwnerTag(t *testing.T) { t.Parallel() diff --git a/internal/command/heartbeat_cmd.go b/internal/command/heartbeat_cmd.go index e57cf7b7..c49ce2bd 100644 --- a/internal/command/heartbeat_cmd.go +++ b/internal/command/heartbeat_cmd.go @@ -40,7 +40,7 @@ func (h *Handler) buildHeartbeatGroup() *CommandGroup { } records = append(records, rec) } - return formatItems(records), nil + return formatLimitedItems(records, 10, "Use the Web UI for older heartbeat logs."), nil }, }) return g diff --git a/internal/command/interfaces.go b/internal/command/interfaces.go index 3f3db905..bb2e0395 100644 --- a/internal/command/interfaces.go +++ b/internal/command/interfaces.go @@ -1,6 +1,13 @@ package command -import "context" +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" + + "github.com/memohai/memoh/internal/acl" + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" +) // Skill represents a single skill loaded from a bot's container. type Skill struct { @@ -25,3 +32,20 @@ type ContainerFS interface { ListDir(ctx context.Context, botID, path string) ([]FSEntry, error) ReadFile(ctx context.Context, botID, path string) (string, error) } + +// CommandQueries captures the sqlc methods used by slash commands. +// *dbsqlc.Queries satisfies this interface directly. +type CommandQueries interface { + GetLatestSessionIDByBot(ctx context.Context, botID pgtype.UUID) (pgtype.UUID, error) + CountMessagesBySession(ctx context.Context, sessionID pgtype.UUID) (int64, error) + GetLatestAssistantUsage(ctx context.Context, sessionID pgtype.UUID) (int64, error) + GetSessionCacheStats(ctx context.Context, sessionID pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) + GetSessionUsedSkills(ctx context.Context, sessionID pgtype.UUID) ([]string, error) + GetTokenUsageByDayAndType(ctx context.Context, arg dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) + GetTokenUsageByModel(ctx context.Context, arg dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) +} + +// AccessEvaluator checks whether the current channel context may trigger chat. +type AccessEvaluator interface { + Evaluate(ctx context.Context, req acl.EvaluateRequest) (bool, error) +} diff --git a/internal/command/mcp.go b/internal/command/mcp.go index dad32c40..40d56a03 100644 --- a/internal/command/mcp.go +++ b/internal/command/mcp.go @@ -27,7 +27,7 @@ func (h *Handler) buildMCPGroup() *CommandGroup { {"Status", item.Status}, }) } - return formatItems(records), nil + return formatLimitedItems(records, defaultListLimit, "Use /mcp get for full details."), nil }, }) g.Register(SubCommand{ diff --git a/internal/command/memory.go b/internal/command/memory.go index fce3d0a1..4965b685 100644 --- a/internal/command/memory.go +++ b/internal/command/memory.go @@ -13,6 +13,9 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { Name: "list", Usage: "list - List all memory providers", Handler: func(cc CommandContext) (string, error) { + if h.memProvService == nil { + return "Memory provider service is not available.", nil + } items, err := h.memProvService.List(cc.Ctx) if err != nil { return "", err @@ -20,18 +23,44 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { if len(items) == 0 { return "No memory providers found.", nil } - records := make([][]kv, 0, len(items)) + settingsResp, _ := h.getBotSettings(cc) + currentRecords := make([][]kv, 0, 1) + otherRecords := make([][]kv, 0, len(items)) for _, item := range items { def := "" if item.IsDefault { def = " (default)" } - records = append(records, []kv{ - {"Name", item.Name + def}, + label := item.Name + def + record := []kv{ + {"Name", label}, {"Provider", item.Provider}, - }) + } + if item.ID == settingsResp.MemoryProviderID { + label += " [current]" + record[0].value = label + currentRecords = append(currentRecords, record) + continue + } + otherRecords = append(otherRecords, record) } - return formatItems(records), nil + currentRecords = append(currentRecords, otherRecords...) + records := currentRecords + return formatLimitedItems(records, defaultListLimit, "Use /memory current to inspect the active provider."), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show the current memory provider", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{{"Memory Provider", h.resolveMemoryProviderName(cc, settingsResp.MemoryProviderID)}}), nil }, }) g.Register(SubCommand{ @@ -42,7 +71,11 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { if len(cc.Args) < 1 { return "Usage: /memory set ", nil } + if h.settingsService == nil { + return "Settings service is not available.", nil + } name := cc.Args[0] + before, _ := h.getBotSettings(cc) items, err := h.memProvService.List(cc.Ctx) if err != nil { return "", err @@ -55,7 +88,7 @@ func (h *Handler) buildMemoryGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Memory provider set to %q.", item.Name), nil + return formatChangedValue("Memory provider", h.resolveMemoryProviderName(cc, before.MemoryProviderID), item.Name), nil } } return fmt.Sprintf("Memory provider %q not found.", name), nil diff --git a/internal/command/model.go b/internal/command/model.go index 53bcc644..856ed721 100644 --- a/internal/command/model.go +++ b/internal/command/model.go @@ -1,7 +1,9 @@ package command import ( + "errors" "fmt" + "sort" "strings" "github.com/memohai/memoh/internal/models" @@ -12,36 +14,81 @@ func (h *Handler) buildModelGroup() *CommandGroup { g := newCommandGroup("model", "Manage bot models") g.Register(SubCommand{ Name: "list", - Usage: "list - List all available chat models", + Usage: "list [provider_name] - List available chat models", Handler: func(cc CommandContext) (string, error) { + if h.modelsService == nil { + return "Model service is not available.", nil + } items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat) if err != nil { return "", err } + filterProvider := "" + if len(cc.Args) > 0 { + filterProvider = strings.TrimSpace(strings.Join(cc.Args, " ")) + } + items = h.filterModelsByProvider(cc, items, filterProvider) if len(items) == 0 { + if filterProvider != "" { + return fmt.Sprintf("No chat models found for provider %q.", filterProvider), nil + } return "No chat models found.", nil } + settingsResp, _ := h.getBotSettings(cc) + sort.SliceStable(items, func(i, j int) bool { + return modelSortRank(items[i], settingsResp) < modelSortRank(items[j], settingsResp) + }) records := make([][]kv, 0, len(items)) for _, item := range items { provName := h.resolveProviderName(cc, item.ProviderID) + label := item.Name + markers := modelMarkers(item.ID, settingsResp) + if len(markers) > 0 { + label += " [" + strings.Join(markers, ", ") + "]" + } records = append(records, []kv{ - {"Model", item.Name}, + {"Model", label}, {"Provider", provName}, {"Model ID", item.ModelID}, }) } - return formatItems(records), nil + hint := "Use /model current to inspect active selections." + if filterProvider == "" { + hint = "Use /model list to narrow results." + } + return formatLimitedItems(records, defaultListLimit, hint), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show current chat and heartbeat models", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{ + {"Chat Model", h.resolveModelName(cc, settingsResp.ChatModelID)}, + {"Heartbeat Model", h.resolveModelName(cc, settingsResp.HeartbeatModelID)}, + }), nil }, }) g.Register(SubCommand{ Name: "set", - Usage: "set - Set the chat model", + Usage: "set | - Set the chat model", IsWrite: true, Handler: func(cc CommandContext) (string, error) { - if len(cc.Args) < 2 { - return "Usage: /model set ", nil + if len(cc.Args) < 1 { + return "Usage: /model set | ", nil } - modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1]) + if h.settingsService == nil { + return "Settings service is not available.", nil + } + before, _ := h.getBotSettings(cc) + modelResp, err := h.findModelForSelection(cc, cc.Args) if err != nil { return "", err } @@ -51,18 +98,22 @@ func (h *Handler) buildModelGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Chat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil + return formatChangedValue("Chat model", h.resolveModelName(cc, before.ChatModelID), h.resolveModelName(cc, modelResp.ID)), nil }, }) g.Register(SubCommand{ Name: "set-heartbeat", - Usage: "set-heartbeat - Set the heartbeat model", + Usage: "set-heartbeat | - Set the heartbeat model", IsWrite: true, Handler: func(cc CommandContext) (string, error) { - if len(cc.Args) < 2 { - return "Usage: /model set-heartbeat ", nil + if len(cc.Args) < 1 { + return "Usage: /model set-heartbeat | ", nil } - modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1]) + if h.settingsService == nil { + return "Settings service is not available.", nil + } + before, _ := h.getBotSettings(cc) + modelResp, err := h.findModelForSelection(cc, cc.Args) if err != nil { return "", err } @@ -72,7 +123,7 @@ func (h *Handler) buildModelGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Heartbeat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil + return formatChangedValue("Heartbeat model", h.resolveModelName(cc, before.HeartbeatModelID), h.resolveModelName(cc, modelResp.ID)), nil }, }) return g @@ -105,3 +156,86 @@ func (h *Handler) findModelByProviderAndName(cc CommandContext, providerName, mo } return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName) } + +func (h *Handler) findModelForSelection(cc CommandContext, args []string) (models.GetResponse, error) { + if h.modelsService == nil { + return models.GetResponse{}, errors.New("model service is not available") + } + if len(args) == 0 { + return models.GetResponse{}, errors.New("model identifier is required") + } + if len(args) == 1 { + return h.findModelByIDOrName(cc, args[0]) + } + return h.findModelByProviderAndName(cc, args[0], strings.Join(args[1:], " ")) +} + +func (h *Handler) findModelByIDOrName(cc CommandContext, target string) (models.GetResponse, error) { + items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat) + if err != nil { + return models.GetResponse{}, err + } + target = strings.TrimSpace(target) + if target == "" { + return models.GetResponse{}, errors.New("model identifier is required") + } + for _, item := range items { + if strings.EqualFold(item.ModelID, target) { + return item, nil + } + } + matches := make([]models.GetResponse, 0, 4) + for _, item := range items { + if strings.EqualFold(item.Name, target) { + matches = append(matches, item) + } + } + switch len(matches) { + case 0: + return models.GetResponse{}, fmt.Errorf("model %q not found", target) + case 1: + return matches[0], nil + default: + choices := make([]string, 0, len(matches)) + for _, item := range matches { + choices = append(choices, fmt.Sprintf("%s/%s", h.resolveProviderName(cc, item.ProviderID), item.ModelID)) + } + return models.GetResponse{}, fmt.Errorf("model %q is ambiguous; use a model ID or provider-qualified name (%s)", target, strings.Join(choices, ", ")) + } +} + +func (h *Handler) filterModelsByProvider(cc CommandContext, items []models.GetResponse, providerName string) []models.GetResponse { + providerName = strings.TrimSpace(providerName) + if providerName == "" { + return items + } + filtered := make([]models.GetResponse, 0, len(items)) + for _, item := range items { + if strings.EqualFold(h.resolveProviderName(cc, item.ProviderID), providerName) { + filtered = append(filtered, item) + } + } + return filtered +} + +func modelMarkers(modelID string, settingsResp settings.Settings) []string { + var markers []string + if modelID == settingsResp.ChatModelID { + markers = append(markers, "chat") + } + if modelID == settingsResp.HeartbeatModelID { + markers = append(markers, "heartbeat") + } + return markers +} + +func modelSortRank(model models.GetResponse, settingsResp settings.Settings) int { + switch len(modelMarkers(model.ID, settingsResp)) { + case 2: + return 0 + case 1: + return 1 + default: + return 2 + } +} diff --git a/internal/command/parser_test.go b/internal/command/parser_test.go index 3956c1b3..5a16ea06 100644 --- a/internal/command/parser_test.go +++ b/internal/command/parser_test.go @@ -13,6 +13,8 @@ func TestParse_Basic(t *testing.T) { args []string }{ {"/help", "help", "", nil}, + {"/help model", "help", "model", nil}, + {"/help model set", "help", "model", []string{"set"}}, {"/subagent list", "subagent", "list", nil}, {"/subagent get mybot", "subagent", "get", []string{"mybot"}}, {"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}}, diff --git a/internal/command/registry.go b/internal/command/registry.go index 80de168a..d6d10d53 100644 --- a/internal/command/registry.go +++ b/internal/command/registry.go @@ -8,10 +8,18 @@ import ( // CommandContext carries execution context for a sub-command. type CommandContext struct { - Ctx context.Context - BotID string - Role string // "owner", "admin", "member", or "" (guest) - Args []string + Ctx context.Context + BotID string + Role string // "owner", "admin", "member", or "" (guest) + Args []string + ChannelIdentityID string + UserID string + ChannelType string + ConversationType string + ConversationID string + ThreadID string + RouteID string + SessionID string } // SubCommand describes a single sub-command within a resource group. @@ -50,15 +58,38 @@ func (g *CommandGroup) Usage() string { fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description) for _, name := range g.order { sub := g.commands[name] - perm := "" + desc := subSummary(sub) if sub.IsWrite { - perm = " [owner]" + desc += " [owner]" } - fmt.Fprintf(&b, "- %s%s\n", sub.Usage, perm) + fmt.Fprintf(&b, "- %s\n", desc) } + fmt.Fprintf(&b, "\nUse /help %s for details.", g.Name) return b.String() } +func (g *CommandGroup) ActionHelp(action string) string { + sub, ok := g.commands[action] + if !ok { + return fmt.Sprintf("Unknown action %q for /%s.\n\n%s", action, g.Name, g.Usage()) + } + usage, summary := splitUsage(sub.Usage) + var b strings.Builder + fmt.Fprintf(&b, "/%s %s\n", g.Name, sub.Name) + if summary != "" { + fmt.Fprintf(&b, "- Summary: %s\n", summary) + } + if usage == "" { + usage = sub.Name + } + fmt.Fprintf(&b, "- Usage: /%s %s\n", g.Name, usage) + if sub.IsWrite { + b.WriteString("- Access: owner only\n") + } + fmt.Fprintf(&b, "- Tip: use /help %s to view sibling actions.", g.Name) + return strings.TrimRight(b.String(), "\n") +} + // Registry holds all registered command groups. type Registry struct { groups map[string]*CommandGroup @@ -83,12 +114,47 @@ func (r *Registry) GlobalHelp() string { b.WriteString("/help - Show this help message\n") b.WriteString("/new - Start a new conversation (resets session context)\n") b.WriteString("/stop - Stop the current generation\n\n") - for i, name := range r.order { - if i > 0 { - b.WriteByte('\n') - } + for _, name := range r.order { group := r.groups[name] - b.WriteString(group.Usage()) + fmt.Fprintf(&b, "- /%s - %s\n", group.Name, group.Description) } + b.WriteString("\nUse /help to view actions, e.g. /help model") return strings.TrimRight(b.String(), "\n") } + +func (r *Registry) GroupHelp(name string) string { + group, ok := r.groups[name] + if !ok { + return fmt.Sprintf("Unknown command group: /%s\n\n%s", name, r.GlobalHelp()) + } + return group.Usage() +} + +func (r *Registry) ActionHelp(groupName, action string) string { + group, ok := r.groups[groupName] + if !ok { + return fmt.Sprintf("Unknown command group: /%s\n\n%s", groupName, r.GlobalHelp()) + } + return group.ActionHelp(action) +} + +func splitUsage(usage string) (commandUsage string, summary string) { + usage = strings.TrimSpace(usage) + if usage == "" { + return "", "" + } + parts := strings.SplitN(usage, " - ", 2) + commandUsage = strings.TrimSpace(parts[0]) + if len(parts) > 1 { + summary = strings.TrimSpace(parts[1]) + } + return commandUsage, summary +} + +func subSummary(sub SubCommand) string { + usage, summary := splitUsage(sub.Usage) + if summary == "" { + return usage + } + return fmt.Sprintf("%s - %s", sub.Name, summary) +} diff --git a/internal/command/search.go b/internal/command/search.go index a18432c5..19ad41f3 100644 --- a/internal/command/search.go +++ b/internal/command/search.go @@ -13,6 +13,9 @@ func (h *Handler) buildSearchGroup() *CommandGroup { Name: "list", Usage: "list - List all search providers", Handler: func(cc CommandContext) (string, error) { + if h.searchProvService == nil { + return "Search provider service is not available.", nil + } items, err := h.searchProvService.List(cc.Ctx, "") if err != nil { return "", err @@ -20,14 +23,40 @@ func (h *Handler) buildSearchGroup() *CommandGroup { if len(items) == 0 { return "No search providers found.", nil } - records := make([][]kv, 0, len(items)) + settingsResp, _ := h.getBotSettings(cc) + currentRecords := make([][]kv, 0, 1) + otherRecords := make([][]kv, 0, len(items)) for _, item := range items { - records = append(records, []kv{ - {"Name", item.Name}, + label := item.Name + record := []kv{ + {"Name", label}, {"Provider", item.Provider}, - }) + } + if item.ID == settingsResp.SearchProviderID { + label += " [current]" + record[0].value = label + currentRecords = append(currentRecords, record) + continue + } + otherRecords = append(otherRecords, record) } - return formatItems(records), nil + currentRecords = append(currentRecords, otherRecords...) + records := currentRecords + return formatLimitedItems(records, defaultListLimit, "Use /search current to inspect the active provider."), nil + }, + }) + g.Register(SubCommand{ + Name: "current", + Usage: "current - Show the current search provider", + Handler: func(cc CommandContext) (string, error) { + if h.settingsService == nil { + return "Settings service is not available.", nil + } + settingsResp, err := h.getBotSettings(cc) + if err != nil { + return "", err + } + return formatKV([]kv{{"Search Provider", h.resolveSearchProviderName(cc, settingsResp.SearchProviderID)}}), nil }, }) g.Register(SubCommand{ @@ -38,7 +67,11 @@ func (h *Handler) buildSearchGroup() *CommandGroup { if len(cc.Args) < 1 { return "Usage: /search set ", nil } + if h.settingsService == nil { + return "Settings service is not available.", nil + } name := cc.Args[0] + before, _ := h.getBotSettings(cc) items, err := h.searchProvService.List(cc.Ctx, "") if err != nil { return "", err @@ -51,7 +84,7 @@ func (h *Handler) buildSearchGroup() *CommandGroup { if err != nil { return "", err } - return fmt.Sprintf("Search provider set to %q.", item.Name), nil + return formatChangedValue("Search provider", h.resolveSearchProviderName(cc, before.SearchProviderID), item.Name), nil } } return fmt.Sprintf("Search provider %q not found.", name), nil diff --git a/internal/command/settings.go b/internal/command/settings.go index b88d9be5..f8ad64d4 100644 --- a/internal/command/settings.go +++ b/internal/command/settings.go @@ -1,6 +1,7 @@ package command import ( + "errors" "fmt" "strconv" "strings" @@ -106,6 +107,13 @@ func settingsUpdateUsage() string { "- --heartbeat_model_id " } +func (h *Handler) getBotSettings(cc CommandContext) (settings.Settings, error) { + if h.settingsService == nil { + return settings.Settings{}, errors.New("settings service is not available") + } + return h.settingsService.GetBot(cc.Ctx, cc.BotID) +} + // resolveModelName resolves a model UUID to "model_name (provider_name)". func (h *Handler) resolveModelName(cc CommandContext, modelID string) string { if modelID == "" { diff --git a/internal/command/status.go b/internal/command/status.go index b68e1a84..1d19774c 100644 --- a/internal/command/status.go +++ b/internal/command/status.go @@ -6,7 +6,9 @@ import ( "strconv" "strings" + "github.com/google/uuid" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" ) func (h *Handler) buildStatusGroup() *CommandGroup { @@ -16,78 +18,113 @@ func (h *Handler) buildStatusGroup() *CommandGroup { Name: "show", Usage: "show - Show current session status", Handler: func(cc CommandContext) (string, error) { + if strings.TrimSpace(cc.SessionID) == "" { + return "No active session found for this conversation.", nil + } + return h.renderSessionStatus(cc, cc.SessionID, "current conversation") + }, + }) + g.Register(SubCommand{ + Name: "latest", + Usage: "latest - Show the latest session status for this bot", + Handler: func(cc CommandContext) (string, error) { + if h.queries == nil { + return "Session info is not available.", nil + } botUUID, err := parseBotUUID(cc.BotID) if err != nil { return "", err } - sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return "No active session found.", nil + return "No session found for this bot.", nil } return "", err } - - msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, sessionID) - if err != nil { - return "", fmt.Errorf("count messages: %w", err) - } - - var usedTokens int64 - latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, sessionID) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return "", fmt.Errorf("get usage: %w", err) - } - if err == nil { - usedTokens = latestUsage - } - - cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, sessionID) - if err != nil { - return "", fmt.Errorf("get cache: %w", err) - } - - var contextWindowStr string - if h.settingsService != nil { - s, sErr := h.settingsService.GetBot(cc.Ctx, cc.BotID) - if sErr == nil && s.ChatModelID != "" && h.modelsService != nil { - m, mErr := h.modelsService.GetByID(cc.Ctx, s.ChatModelID) - if mErr == nil && m.Config.ContextWindow != nil { - contextWindowStr = formatTokens(int64(*m.Config.ContextWindow)) - } - } - } - - var cacheHitRate float64 - if cacheRow.TotalInputTokens > 0 { - cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100 - } - - skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, sessionID) - - var b strings.Builder - b.WriteString("Session Status:\n\n") - fmt.Fprintf(&b, "- Messages: %d\n", msgCount) - if contextWindowStr != "" { - fmt.Fprintf(&b, "- Context: %s / %s\n", formatTokens(usedTokens), contextWindowStr) - } else { - fmt.Fprintf(&b, "- Context: %s\n", formatTokens(usedTokens)) - } - fmt.Fprintf(&b, "- Cache Hit Rate: %.1f%%\n", cacheHitRate) - fmt.Fprintf(&b, "- Cache Read: %s\n", formatTokens(cacheRow.CacheReadTokens)) - fmt.Fprintf(&b, "- Cache Write: %s\n", formatTokens(cacheRow.CacheWriteTokens)) - - if len(skills) > 0 { - fmt.Fprintf(&b, "- Skills: %s\n", strings.Join(skills, ", ")) - } - - return strings.TrimRight(b.String(), "\n"), nil + return h.renderSessionStatus(cc, sessionID.String(), "latest bot session") }, }) return g } +func (h *Handler) renderSessionStatus(cc CommandContext, sessionID string, scope string) (string, error) { + if h.queries == nil { + return "Session info is not available.", nil + } + pgSessionID, err := parseCommandUUID(sessionID) + if err != nil { + return "", err + } + msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, pgSessionID) + if err != nil { + return "", fmt.Errorf("count messages: %w", err) + } + + var usedTokens int64 + latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, pgSessionID) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return "", fmt.Errorf("get usage: %w", err) + } + if err == nil { + usedTokens = latestUsage + } + + cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, pgSessionID) + if err != nil { + return "", fmt.Errorf("get cache: %w", err) + } + + var cacheHitRate float64 + if cacheRow.TotalInputTokens > 0 { + cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100 + } + + skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, pgSessionID) + + contextUsage := formatTokens(usedTokens) + if contextWindow := h.resolveContextWindow(cc); contextWindow != "" { + contextUsage = contextUsage + " / " + contextWindow + } + + pairs := []kv{ + {"Scope", scope}, + {"Session ID", sessionID}, + {"Messages", strconv.FormatInt(msgCount, 10)}, + {"Context", contextUsage}, + {"Cache Hit Rate", fmt.Sprintf("%.1f%%", cacheHitRate)}, + {"Cache Read", formatTokens(cacheRow.CacheReadTokens)}, + {"Cache Write", formatTokens(cacheRow.CacheWriteTokens)}, + } + if len(skills) > 0 { + pairs = append(pairs, kv{"Skills", strings.Join(skills, ", ")}) + } + return formatKV(pairs), nil +} + +func (h *Handler) resolveContextWindow(cc CommandContext) string { + if h.settingsService == nil || h.modelsService == nil { + return "" + } + s, err := h.settingsService.GetBot(cc.Ctx, cc.BotID) + if err != nil || s.ChatModelID == "" { + return "" + } + m, err := h.modelsService.GetByID(cc.Ctx, s.ChatModelID) + if err != nil || m.Config.ContextWindow == nil { + return "" + } + return formatTokens(int64(*m.Config.ContextWindow)) +} + +func parseCommandUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(strings.TrimSpace(id)) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid uuid: %w", err) + } + return pgtype.UUID{Bytes: parsed, Valid: true}, nil +} + func formatTokens(n int64) string { if n >= 1_000_000 { return fmt.Sprintf("%.1fM", float64(n)/1_000_000) diff --git a/internal/command/usage.go b/internal/command/usage.go index a08792a8..4ed5b40a 100644 --- a/internal/command/usage.go +++ b/internal/command/usage.go @@ -18,6 +18,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup { Name: "summary", Usage: "summary - Token usage summary (last 7 days)", Handler: func(cc CommandContext) (string, error) { + if h.queries == nil { + return "Usage info is not available.", nil + } botUUID, err := parseBotUUID(cc.BotID) if err != nil { return "", err @@ -89,6 +92,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup { Name: "by-model", Usage: "by-model - Token usage grouped by model", Handler: func(cc CommandContext) (string, error) { + if h.queries == nil { + return "Usage info is not available.", nil + } botUUID, err := parseBotUUID(cc.BotID) if err != nil { return "", err