feat(command): improve slash command UX

Make slash commands easier to navigate in chat by splitting help into levels, compacting list output, and surfacing current selections for model, search, memory, and browser settings. Also route /status to the active conversation session and add an access inspector so users can understand their current command and ACL context.
This commit is contained in:
Acbox
2026-04-12 17:25:10 +08:00
parent 3307b27a80
commit 0549f5cafc
22 changed files with 1080 additions and 138 deletions
+9
View File
@@ -604,6 +604,7 @@ func provideChannelRouter(
emailOutboxService,
heartbeatService,
queries,
aclService,
&commandSkillLoaderAdapter{handler: containerdHandler},
&commandContainerFSAdapter{manager: manager},
))
@@ -779,6 +780,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
}
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
if err != nil {
return inbound.SessionResult{}, err
}
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
}
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
if err != nil {
+9
View File
@@ -492,6 +492,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc
emailOutboxService,
heartbeatService,
queries,
aclService,
&commandSkillLoaderAdapter{handler: containerdHandler},
&commandContainerFSAdapter{manager: manager},
))
@@ -877,6 +878,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
}
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
if err != nil {
return inbound.SessionResult{}, err
}
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
}
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
if err != nil {
+106 -1
View File
@@ -71,6 +71,7 @@ type ttsModelResolver interface {
// SessionEnsurer resolves or creates an active session for a route.
type SessionEnsurer interface {
EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error)
GetActiveSession(ctx context.Context, routeID string) (SessionResult, error)
// CreateNewSession always creates a fresh session and sets it as the
// active session for the given route, replacing any previous one.
// sessionType defaults to "chat" if empty.
@@ -298,11 +299,23 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
if isStopCommand(cmdText) && isDirectedAtBot(msg) {
return p.handleStopCommand(ctx, cfg, msg, sender, identity)
}
if isStatusCommand(cmdText) && isDirectedAtBot(msg) {
return p.handleStatusCommand(ctx, cfg, msg, sender, identity)
}
// Skip generic command handler for mode-prefix commands (/btw, /now, /next)
// so they pass through to mode detection below.
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText)
reply, err := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
BotID: strings.TrimSpace(identity.BotID),
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
UserID: strings.TrimSpace(identity.UserID),
Text: cmdText,
ChannelType: msg.Channel.String(),
ConversationType: strings.TrimSpace(msg.Conversation.Type),
ConversationID: strings.TrimSpace(msg.Conversation.ID),
ThreadID: extractThreadID(msg),
})
if err != nil {
reply = "Error: " + err.Error()
}
@@ -2489,6 +2502,18 @@ func isNewSessionCommand(cmdText string) bool {
return parsed.Resource == "new"
}
func isStatusCommand(cmdText string) bool {
extracted := command.ExtractCommandText(cmdText)
if extracted == "" {
return false
}
parsed, err := command.Parse(extracted)
if err != nil {
return false
}
return parsed.Resource == "status"
}
// resolveNewSessionType determines the session type for /new command.
// /new chat → chat, /new discuss → discuss, /new (no arg) → default by context.
// WebUI (local channel) always defaults to chat.
@@ -2611,3 +2636,83 @@ func (p *ChannelInboundProcessor) handleNewSessionCommand(
Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)},
})
}
func (p *ChannelInboundProcessor) handleStatusCommand(
ctx context.Context,
cfg channel.ChannelConfig,
msg channel.InboundMessage,
sender channel.StreamReplySender,
identity InboundIdentity,
) error {
target := strings.TrimSpace(msg.ReplyTarget)
if target == "" {
return errors.New("reply target missing for /status command")
}
if p.routeResolver == nil {
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: "Error: route resolver not configured."},
})
}
if p.commandHandler == nil {
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: "Error: command handler not configured."},
})
}
threadID := extractThreadID(msg)
routeMetadata := buildRouteMetadata(msg, identity)
p.enrichConversationAvatar(ctx, cfg, msg, routeMetadata)
resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{
BotID: identity.BotID,
Platform: msg.Channel.String(),
ConversationID: msg.Conversation.ID,
ThreadID: threadID,
ConversationType: msg.Conversation.Type,
ChannelIdentityID: identity.UserID,
ChannelConfigID: identity.ChannelConfigID,
ReplyTarget: target,
Metadata: routeMetadata,
})
if err != nil {
if p.logger != nil {
p.logger.Warn("resolve route for /status command failed", slog.Any("error", err))
}
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: "Error: failed to resolve conversation route."},
})
}
sessionID := ""
if p.sessionEnsurer != nil {
sess, sessErr := p.sessionEnsurer.GetActiveSession(ctx, resolved.RouteID)
if sessErr == nil {
sessionID = strings.TrimSpace(sess.ID)
} else if p.logger != nil {
p.logger.Debug("resolve active session for /status command failed", slog.Any("error", sessErr))
}
}
reply, execErr := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
BotID: strings.TrimSpace(identity.BotID),
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
UserID: strings.TrimSpace(identity.UserID),
Text: rawTextForCommand(msg, strings.TrimSpace(msg.Message.PlainText())),
ChannelType: msg.Channel.String(),
ConversationType: strings.TrimSpace(msg.Conversation.Type),
ConversationID: strings.TrimSpace(msg.Conversation.ID),
ThreadID: threadID,
RouteID: strings.TrimSpace(resolved.RouteID),
SessionID: sessionID,
})
if execErr != nil {
reply = "Error: " + execErr.Error()
}
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: reply},
})
}
+135
View File
@@ -13,11 +13,15 @@ import (
"strings"
"testing"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/acl"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/identities"
"github.com/memohai/memoh/internal/channel/route"
"github.com/memohai/memoh/internal/command"
"github.com/memohai/memoh/internal/conversation"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/media"
messagepkg "github.com/memohai/memoh/internal/message"
"github.com/memohai/memoh/internal/schedule"
@@ -185,6 +189,71 @@ type fakeChatACL struct {
lastReq acl.EvaluateRequest
}
type fakeSessionEnsurer struct {
activeSession SessionResult
activeErr error
lastRouteID string
}
func (f *fakeSessionEnsurer) EnsureActiveSession(_ context.Context, _, routeID, _ string) (SessionResult, error) {
f.lastRouteID = routeID
if f.activeErr != nil {
return SessionResult{}, f.activeErr
}
return f.activeSession, nil
}
func (f *fakeSessionEnsurer) GetActiveSession(_ context.Context, routeID string) (SessionResult, error) {
f.lastRouteID = routeID
if f.activeErr != nil {
return SessionResult{}, f.activeErr
}
return f.activeSession, nil
}
func (f *fakeSessionEnsurer) CreateNewSession(_ context.Context, _, routeID, _, _ string) (SessionResult, error) {
f.lastRouteID = routeID
if f.activeErr != nil {
return SessionResult{}, f.activeErr
}
return f.activeSession, nil
}
type fakeCommandQueries struct {
messageCount int64
usage int64
cacheRow dbsqlc.GetSessionCacheStatsRow
skills []string
}
func (*fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
return pgtype.UUID{}, errors.New("unexpected latest session lookup")
}
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
return f.messageCount, nil
}
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
return f.usage, nil
}
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
return f.cacheRow, nil
}
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
return f.skills, nil
}
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
return nil, nil
}
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
return nil, nil
}
func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) {
f.calls++
f.lastReq = req
@@ -548,6 +617,72 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) {
}
}
func TestChannelInboundProcessorStatusUsesRouteSession(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-status"}}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-status", RouteID: "route-status"}}
gateway := &fakeChatGateway{}
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
processor.SetSessionEnsurer(&fakeSessionEnsurer{
activeSession: SessionResult{ID: "11111111-1111-1111-1111-111111111111", Type: "chat"},
})
processor.SetCommandHandler(command.NewHandler(
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
&fakeCommandQueries{
messageCount: 9,
usage: 512,
cacheRow: dbsqlc.GetSessionCacheStatsRow{
CacheReadTokens: 64,
CacheWriteTokens: 32,
TotalInputTokens: 512,
},
skills: []string{"search"},
},
nil,
nil,
nil,
))
sender := &fakeReplySender{}
cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("discord")}
msg := channel.InboundMessage{
BotID: "bot-1",
Channel: channel.ChannelType("discord"),
Message: channel.Message{Text: "/status"},
ReplyTarget: "discord:status",
Sender: channel.Identity{SubjectID: "user-1"},
Conversation: channel.Conversation{
ID: "conv-status",
Type: channel.ConversationTypePrivate,
},
}
if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one status reply, got %d", len(sender.sent))
}
if !strings.Contains(sender.sent[0].Message.Text, "- Scope: current conversation") {
t.Fatalf("expected current conversation scope, got %q", sender.sent[0].Message.Text)
}
if !strings.Contains(sender.sent[0].Message.Text, "- Session ID: 11111111-1111-1111-1111-111111111111") {
t.Fatalf("expected active route session in reply, got %q", sender.sent[0].Message.Text)
}
}
func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) {
t.Parallel()
+78
View File
@@ -0,0 +1,78 @@
package command
import (
"fmt"
"strings"
"github.com/memohai/memoh/internal/acl"
)
func (h *Handler) buildAccessGroup() *CommandGroup {
g := newCommandGroup("access", "Inspect identity and permission context")
g.DefaultAction = "show"
g.Register(SubCommand{
Name: "show",
Usage: "show - Show current identity, write access, and chat ACL context",
Handler: func(cc CommandContext) (string, error) {
writeAccess := "no"
if cc.Role == "owner" {
writeAccess = "yes"
}
pairs := []kv{
{"Channel Identity", fallbackValue(cc.ChannelIdentityID)},
{"Linked User", fallbackValue(cc.UserID)},
{"Bot Role", fallbackValue(cc.Role)},
{"Write Commands", writeAccess},
{"Channel", fallbackValue(cc.ChannelType)},
{"Conversation Type", fallbackValue(cc.ConversationType)},
{"Conversation ID", fallbackValue(cc.ConversationID)},
{"Thread ID", fallbackValue(cc.ThreadID)},
}
if strings.TrimSpace(cc.RouteID) != "" {
pairs = append(pairs, kv{"Route ID", cc.RouteID})
}
if strings.TrimSpace(cc.SessionID) != "" {
pairs = append(pairs, kv{"Session ID", cc.SessionID})
}
aclStatus := "unavailable"
if h.aclEvaluator != nil && strings.TrimSpace(cc.ChannelType) != "" {
allowed, err := h.aclEvaluator.Evaluate(cc.Ctx, acl.EvaluateRequest{
BotID: cc.BotID,
ChannelIdentityID: cc.ChannelIdentityID,
ChannelType: cc.ChannelType,
SourceScope: acl.SourceScope{
ConversationType: cc.ConversationType,
ConversationID: cc.ConversationID,
ThreadID: cc.ThreadID,
},
})
switch {
case err != nil:
aclStatus = "error: " + err.Error()
case allowed:
aclStatus = "allow"
default:
aclStatus = "deny"
}
}
pairs = append(pairs, kv{"Chat ACL", aclStatus})
return formatKV(pairs), nil
},
})
return g
}
func fallbackValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return "(none)"
}
return value
}
func formatChangedValue(label, before, after string) string {
return fmt.Sprintf("%s: %s -> %s", label, fallbackValue(before), fallbackValue(after))
}
+37 -6
View File
@@ -13,6 +13,9 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
Name: "list",
Usage: "list - List all browser contexts",
Handler: func(cc CommandContext) (string, error) {
if h.browserCtxService == nil {
return "Browser context service is not available.", nil
}
items, err := h.browserCtxService.List(cc.Ctx)
if err != nil {
return "", err
@@ -20,13 +23,37 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
if len(items) == 0 {
return "No browser contexts found.", nil
}
records := make([][]kv, 0, len(items))
settingsResp, _ := h.getBotSettings(cc)
currentRecords := make([][]kv, 0, 1)
otherRecords := make([][]kv, 0, len(items))
for _, item := range items {
records = append(records, []kv{
{"Name", item.Name},
})
label := item.Name
record := []kv{{"Name", label}}
if item.ID == settingsResp.BrowserContextID {
label += " [current]"
record[0].value = label
currentRecords = append(currentRecords, record)
continue
}
otherRecords = append(otherRecords, record)
}
return formatItems(records), nil
currentRecords = append(currentRecords, otherRecords...)
records := currentRecords
return formatLimitedItems(records, defaultListLimit, "Use /browser current to inspect the active context."), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show the current browser context",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{{"Browser Context", h.resolveBrowserContextName(cc, settingsResp.BrowserContextID)}}), nil
},
})
g.Register(SubCommand{
@@ -37,7 +64,11 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
if len(cc.Args) < 1 {
return "Usage: /browser set <name>", nil
}
if h.settingsService == nil {
return "Settings service is not available.", nil
}
name := cc.Args[0]
before, _ := h.getBotSettings(cc)
items, err := h.browserCtxService.List(cc.Ctx)
if err != nil {
return "", err
@@ -50,7 +81,7 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
if err != nil {
return "", err
}
return fmt.Sprintf("Browser context set to %q.", item.Name), nil
return formatChangedValue("Browser context", h.resolveBrowserContextName(cc, before.BrowserContextID), item.Name), nil
}
}
return fmt.Sprintf("Browser context %q not found.", name), nil
+1
View File
@@ -16,5 +16,6 @@ func (h *Handler) buildRegistry() *Registry {
r.RegisterGroup(h.buildSkillGroup())
r.RegisterGroup(h.buildFSGroup())
r.RegisterGroup(h.buildStatusGroup())
r.RegisterGroup(h.buildAccessGroup())
return r
}
+3 -3
View File
@@ -24,7 +24,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
{"Provider", item.Provider},
})
}
return formatItems(records), nil
return formatLimitedItems(records, defaultListLimit, "Use /email bindings to inspect bot bindings."), nil
},
})
g.Register(SubCommand{
@@ -46,7 +46,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
{"Permissions", perms},
})
}
return formatItems(records), nil
return formatLimitedItems(records, defaultListLimit, "Use /email outbox to inspect recent sends."), nil
},
})
g.Register(SubCommand{
@@ -70,7 +70,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
{"Sent", item.SentAt.Format("01-02 15:04")},
})
}
return formatItems(records), nil
return formatLimitedItems(records, 10, "Use the Web UI for older outbox entries."), nil
},
})
return g
+32 -14
View File
@@ -6,19 +6,10 @@ import (
"unicode/utf8"
)
const defaultListLimit = 12
// formatItems renders a list of records as a Markdown-style list.
// Each record is a slice of kv pairs; the first pair's value is used as the
// bullet title, and subsequent pairs are indented beneath it.
//
// Example output:
//
// - mybot
// Description: A helpful assistant
// ID: abc123
//
// - another
// Description: Something else
// ID: def456
// Each record is rendered on a single line so long lists stay readable in IM.
func formatItems(items [][]kv) string {
if len(items) == 0 {
return ""
@@ -31,14 +22,41 @@ func formatItems(items [][]kv) string {
if i > 0 {
b.WriteByte('\n')
}
fmt.Fprintf(&b, "- %s\n", record[0].value)
fmt.Fprintf(&b, "- %s", record[0].value)
extras := make([]string, 0, len(record)-1)
for _, pair := range record[1:] {
fmt.Fprintf(&b, " %s: %s\n", pair.key, pair.value)
if strings.TrimSpace(pair.value) == "" {
continue
}
extras = append(extras, fmt.Sprintf("%s: %s", pair.key, pair.value))
}
if len(extras) > 0 {
fmt.Fprintf(&b, " | %s", strings.Join(extras, " | "))
}
}
return b.String()
}
func formatLimitedItems(items [][]kv, limit int, hint string) string {
if len(items) == 0 {
return ""
}
if limit <= 0 {
limit = defaultListLimit
}
total := len(items)
if total <= limit {
return formatItems(items)
}
shown := items[:limit]
result := formatItems(shown)
suffix := fmt.Sprintf("Showing %d of %d items.", len(shown), total)
if strings.TrimSpace(hint) != "" {
suffix += " " + strings.TrimSpace(hint)
}
return result + "\n\n" + suffix
}
// formatKV renders key-value pairs as a simple Markdown list.
//
// Example output:
+57 -13
View File
@@ -4,10 +4,10 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/browsercontexts"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
emailpkg "github.com/memohai/memoh/internal/email"
"github.com/memohai/memoh/internal/heartbeat"
"github.com/memohai/memoh/internal/mcp"
@@ -56,13 +56,28 @@ type Handler struct {
emailService *emailpkg.Service
emailOutboxService *emailpkg.OutboxService
heartbeatService *heartbeat.Service
queries *dbsqlc.Queries
queries CommandQueries
aclEvaluator AccessEvaluator
skillLoader SkillLoader
containerFS ContainerFS
logger *slog.Logger
}
// ExecuteInput carries the caller identity and channel context for command execution.
type ExecuteInput struct {
BotID string
ChannelIdentityID string
UserID string
Text string
ChannelType string
ConversationType string
ConversationID string
ThreadID string
RouteID string
SessionID string
}
// NewHandler creates a Handler with all required services.
func NewHandler(
log *slog.Logger,
@@ -78,7 +93,8 @@ func NewHandler(
emailService *emailpkg.Service,
emailOutboxService *emailpkg.OutboxService,
heartbeatService *heartbeat.Service,
queries *dbsqlc.Queries,
queries CommandQueries,
aclEvaluator AccessEvaluator,
skillLoader SkillLoader,
containerFS ContainerFS,
) *Handler {
@@ -99,6 +115,7 @@ func NewHandler(
emailOutboxService: emailOutboxService,
heartbeatService: heartbeatService,
queries: queries,
aclEvaluator: aclEvaluator,
skillLoader: skillLoader,
containerFS: containerFS,
logger: log.With(slog.String("component", "command")),
@@ -140,7 +157,16 @@ func (h *Handler) IsCommand(text string) bool {
// Execute parses and runs a slash command, returning the text reply.
func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) {
cmdText := ExtractCommandText(text)
return h.ExecuteWithInput(ctx, ExecuteInput{
BotID: botID,
ChannelIdentityID: channelIdentityID,
Text: text,
})
}
// ExecuteWithInput parses and runs a slash command with channel/session context.
func (h *Handler) ExecuteWithInput(ctx context.Context, input ExecuteInput) (string, error) {
cmdText := ExtractCommandText(input.Text)
if cmdText == "" {
return h.registry.GlobalHelp(), nil
}
@@ -151,12 +177,16 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
// Resolve the user's role in this bot.
role := ""
if h.roleResolver != nil && channelIdentityID != "" {
r, err := h.roleResolver.GetMemberRole(ctx, botID, channelIdentityID)
roleIdentityID := input.ChannelIdentityID
if strings.TrimSpace(input.UserID) != "" {
roleIdentityID = strings.TrimSpace(input.UserID)
}
if h.roleResolver != nil && roleIdentityID != "" {
r, err := h.roleResolver.GetMemberRole(ctx, input.BotID, roleIdentityID)
if err != nil {
h.logger.Warn("failed to resolve member role",
slog.String("bot_id", botID),
slog.String("channel_identity_id", channelIdentityID),
slog.String("bot_id", input.BotID),
slog.String("role_identity_id", roleIdentityID),
slog.Any("error", err),
)
} else {
@@ -165,15 +195,29 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
}
cc := CommandContext{
Ctx: ctx,
BotID: botID,
Role: role,
Args: parsed.Args,
Ctx: ctx,
BotID: input.BotID,
Role: role,
Args: parsed.Args,
ChannelIdentityID: strings.TrimSpace(input.ChannelIdentityID),
UserID: strings.TrimSpace(input.UserID),
ChannelType: strings.TrimSpace(input.ChannelType),
ConversationType: strings.TrimSpace(input.ConversationType),
ConversationID: strings.TrimSpace(input.ConversationID),
ThreadID: strings.TrimSpace(input.ThreadID),
RouteID: strings.TrimSpace(input.RouteID),
SessionID: strings.TrimSpace(input.SessionID),
}
// /help
if parsed.Resource == "help" {
return h.registry.GlobalHelp(), nil
if parsed.Action == "" {
return h.registry.GlobalHelp(), nil
}
if len(parsed.Args) == 0 {
return h.registry.GroupHelp(parsed.Action), nil
}
return h.registry.ActionHelp(parsed.Action, parsed.Args[0]), nil
}
// Top-level commands (e.g. /new) are handled by the channel inbound
+173 -4
View File
@@ -5,6 +5,10 @@ import (
"strings"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/mcp"
"github.com/memohai/memoh/internal/schedule"
"github.com/memohai/memoh/internal/settings"
@@ -25,9 +29,58 @@ type fakeScheduleService struct {
items []schedule.Schedule
}
type fakeCommandQueries struct {
latestSessionID pgtype.UUID
latestSessionErr error
messageCount int64
latestUsage int64
latestUsageErr error
cacheRow dbsqlc.GetSessionCacheStatsRow
cacheErr error
skills []string
}
func (f *fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
return f.latestSessionID, f.latestSessionErr
}
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
return f.messageCount, nil
}
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
if f.latestUsageErr != nil {
return 0, f.latestUsageErr
}
return f.latestUsage, nil
}
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
if f.cacheErr != nil {
return dbsqlc.GetSessionCacheStatsRow{}, f.cacheErr
}
return f.cacheRow, nil
}
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
return f.skills, nil
}
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
return nil, nil
}
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
return nil, nil
}
// newTestHandler creates a Handler with nil services for use in tests.
func newTestHandler(roleResolver MemberRoleResolver) *Handler {
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
}
func newTestHandlerWithQueries(roleResolver MemberRoleResolver, queries CommandQueries) *Handler {
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, queries, nil, nil, nil)
}
// --- tests ---
@@ -73,6 +126,42 @@ func TestExecute_Help(t *testing.T) {
if !strings.Contains(result, "Available commands") {
t.Errorf("expected help text, got: %s", result)
}
if strings.Contains(result, "set-heartbeat") {
t.Errorf("top-level help should not expand nested actions, got: %s", result)
}
if !strings.Contains(result, "- /model - Manage bot models") {
t.Errorf("expected top-level model entry, got: %s", result)
}
}
func TestExecute_HelpGroup(t *testing.T) {
t.Parallel()
h := newTestHandler(&fakeRoleResolver{role: "owner"})
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "/model - Manage bot models") {
t.Errorf("expected group help, got: %s", result)
}
if !strings.Contains(result, "- set - Set the chat model [owner]") {
t.Errorf("expected compact action summary, got: %s", result)
}
}
func TestExecute_HelpAction(t *testing.T) {
t.Parallel()
h := newTestHandler(&fakeRoleResolver{role: "owner"})
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model set")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "Usage: /model set <model_id> | <provider_name> <model_name>") {
t.Errorf("expected action usage, got: %s", result)
}
if !strings.Contains(result, "Access: owner only") {
t.Errorf("expected owner hint, got: %s", result)
}
}
func TestExecute_UnknownCommand(t *testing.T) {
@@ -207,8 +296,8 @@ func TestFormatItems(t *testing.T) {
if !strings.Contains(result, "- foo") {
t.Errorf("expected '- foo' bullet, got: %s", result)
}
if !strings.Contains(result, " Type: bar") {
t.Errorf("expected indented 'Type: bar', got: %s", result)
if !strings.Contains(result, "- foo | Type: bar") {
t.Errorf("expected compact line entry, got: %s", result)
}
if !strings.Contains(result, "- longname") {
t.Errorf("expected '- longname' bullet, got: %s", result)
@@ -255,7 +344,7 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
for _, group := range []string{
"schedule", "mcp", "settings",
"model", "memory", "search", "browser", "usage",
"email", "heartbeat", "skill", "fs",
"email", "heartbeat", "skill", "fs", "access",
} {
if !strings.Contains(help, "/"+group) {
t.Errorf("missing /%s in global help", group)
@@ -263,6 +352,86 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
}
}
func TestExecuteWithInput_Access(t *testing.T) {
t.Parallel()
h := newTestHandler(&fakeRoleResolver{role: "owner"})
result, err := h.ExecuteWithInput(context.Background(), ExecuteInput{
BotID: "bot-1",
ChannelIdentityID: "channel-id-1",
UserID: "user-id-1",
Text: "/access",
ChannelType: "discord",
ConversationType: "thread",
ConversationID: "conv-1",
ThreadID: "thread-1",
RouteID: "route-1",
SessionID: "session-1",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "- Channel Identity: channel-id-1") {
t.Errorf("expected channel identity in access output, got: %s", result)
}
if !strings.Contains(result, "- Write Commands: yes") {
t.Errorf("expected write access in access output, got: %s", result)
}
}
func TestExecute_StatusLatest(t *testing.T) {
t.Parallel()
sessionUUID := pgtype.UUID{}
copy(sessionUUID.Bytes[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
sessionUUID.Valid = true
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
latestSessionID: sessionUUID,
messageCount: 42,
latestUsage: 1200,
cacheRow: dbsqlc.GetSessionCacheStatsRow{
CacheReadTokens: 300,
CacheWriteTokens: 150,
TotalInputTokens: 1200,
},
skills: []string{"search", "browser"},
})
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "- Scope: latest bot session") {
t.Errorf("expected latest scope, got: %s", result)
}
if !strings.Contains(result, "- Messages: 42") {
t.Errorf("expected message count, got: %s", result)
}
}
func TestExecute_StatusLatestNoRows(t *testing.T) {
t.Parallel()
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
latestSessionErr: pgx.ErrNoRows,
})
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "No session found for this bot.") {
t.Errorf("expected no session message, got: %s", result)
}
}
func TestExecute_StatusShowWithoutSession(t *testing.T) {
t.Parallel()
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{})
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/status")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "No active session found for this conversation.") {
t.Errorf("expected route-aware no session message, got: %s", result)
}
}
// Verify write commands are tagged with [owner] in usage.
func TestUsage_OwnerTag(t *testing.T) {
t.Parallel()
+1 -1
View File
@@ -40,7 +40,7 @@ func (h *Handler) buildHeartbeatGroup() *CommandGroup {
}
records = append(records, rec)
}
return formatItems(records), nil
return formatLimitedItems(records, 10, "Use the Web UI for older heartbeat logs."), nil
},
})
return g
+25 -1
View File
@@ -1,6 +1,13 @@
package command
import "context"
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/acl"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
)
// Skill represents a single skill loaded from a bot's container.
type Skill struct {
@@ -25,3 +32,20 @@ type ContainerFS interface {
ListDir(ctx context.Context, botID, path string) ([]FSEntry, error)
ReadFile(ctx context.Context, botID, path string) (string, error)
}
// CommandQueries captures the sqlc methods used by slash commands.
// *dbsqlc.Queries satisfies this interface directly.
type CommandQueries interface {
GetLatestSessionIDByBot(ctx context.Context, botID pgtype.UUID) (pgtype.UUID, error)
CountMessagesBySession(ctx context.Context, sessionID pgtype.UUID) (int64, error)
GetLatestAssistantUsage(ctx context.Context, sessionID pgtype.UUID) (int64, error)
GetSessionCacheStats(ctx context.Context, sessionID pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error)
GetSessionUsedSkills(ctx context.Context, sessionID pgtype.UUID) ([]string, error)
GetTokenUsageByDayAndType(ctx context.Context, arg dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error)
GetTokenUsageByModel(ctx context.Context, arg dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error)
}
// AccessEvaluator checks whether the current channel context may trigger chat.
type AccessEvaluator interface {
Evaluate(ctx context.Context, req acl.EvaluateRequest) (bool, error)
}
+1 -1
View File
@@ -27,7 +27,7 @@ func (h *Handler) buildMCPGroup() *CommandGroup {
{"Status", item.Status},
})
}
return formatItems(records), nil
return formatLimitedItems(records, defaultListLimit, "Use /mcp get <name> for full details."), nil
},
})
g.Register(SubCommand{
+39 -6
View File
@@ -13,6 +13,9 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
Name: "list",
Usage: "list - List all memory providers",
Handler: func(cc CommandContext) (string, error) {
if h.memProvService == nil {
return "Memory provider service is not available.", nil
}
items, err := h.memProvService.List(cc.Ctx)
if err != nil {
return "", err
@@ -20,18 +23,44 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
if len(items) == 0 {
return "No memory providers found.", nil
}
records := make([][]kv, 0, len(items))
settingsResp, _ := h.getBotSettings(cc)
currentRecords := make([][]kv, 0, 1)
otherRecords := make([][]kv, 0, len(items))
for _, item := range items {
def := ""
if item.IsDefault {
def = " (default)"
}
records = append(records, []kv{
{"Name", item.Name + def},
label := item.Name + def
record := []kv{
{"Name", label},
{"Provider", item.Provider},
})
}
if item.ID == settingsResp.MemoryProviderID {
label += " [current]"
record[0].value = label
currentRecords = append(currentRecords, record)
continue
}
otherRecords = append(otherRecords, record)
}
return formatItems(records), nil
currentRecords = append(currentRecords, otherRecords...)
records := currentRecords
return formatLimitedItems(records, defaultListLimit, "Use /memory current to inspect the active provider."), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show the current memory provider",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{{"Memory Provider", h.resolveMemoryProviderName(cc, settingsResp.MemoryProviderID)}}), nil
},
})
g.Register(SubCommand{
@@ -42,7 +71,11 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
if len(cc.Args) < 1 {
return "Usage: /memory set <name>", nil
}
if h.settingsService == nil {
return "Settings service is not available.", nil
}
name := cc.Args[0]
before, _ := h.getBotSettings(cc)
items, err := h.memProvService.List(cc.Ctx)
if err != nil {
return "", err
@@ -55,7 +88,7 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
if err != nil {
return "", err
}
return fmt.Sprintf("Memory provider set to %q.", item.Name), nil
return formatChangedValue("Memory provider", h.resolveMemoryProviderName(cc, before.MemoryProviderID), item.Name), nil
}
}
return fmt.Sprintf("Memory provider %q not found.", name), nil
+147 -13
View File
@@ -1,7 +1,9 @@
package command
import (
"errors"
"fmt"
"sort"
"strings"
"github.com/memohai/memoh/internal/models"
@@ -12,36 +14,81 @@ func (h *Handler) buildModelGroup() *CommandGroup {
g := newCommandGroup("model", "Manage bot models")
g.Register(SubCommand{
Name: "list",
Usage: "list - List all available chat models",
Usage: "list [provider_name] - List available chat models",
Handler: func(cc CommandContext) (string, error) {
if h.modelsService == nil {
return "Model service is not available.", nil
}
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
if err != nil {
return "", err
}
filterProvider := ""
if len(cc.Args) > 0 {
filterProvider = strings.TrimSpace(strings.Join(cc.Args, " "))
}
items = h.filterModelsByProvider(cc, items, filterProvider)
if len(items) == 0 {
if filterProvider != "" {
return fmt.Sprintf("No chat models found for provider %q.", filterProvider), nil
}
return "No chat models found.", nil
}
settingsResp, _ := h.getBotSettings(cc)
sort.SliceStable(items, func(i, j int) bool {
return modelSortRank(items[i], settingsResp) < modelSortRank(items[j], settingsResp)
})
records := make([][]kv, 0, len(items))
for _, item := range items {
provName := h.resolveProviderName(cc, item.ProviderID)
label := item.Name
markers := modelMarkers(item.ID, settingsResp)
if len(markers) > 0 {
label += " [" + strings.Join(markers, ", ") + "]"
}
records = append(records, []kv{
{"Model", item.Name},
{"Model", label},
{"Provider", provName},
{"Model ID", item.ModelID},
})
}
return formatItems(records), nil
hint := "Use /model current to inspect active selections."
if filterProvider == "" {
hint = "Use /model list <provider_name> to narrow results."
}
return formatLimitedItems(records, defaultListLimit, hint), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show current chat and heartbeat models",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{
{"Chat Model", h.resolveModelName(cc, settingsResp.ChatModelID)},
{"Heartbeat Model", h.resolveModelName(cc, settingsResp.HeartbeatModelID)},
}), nil
},
})
g.Register(SubCommand{
Name: "set",
Usage: "set <provider_name> <model_name> - Set the chat model",
Usage: "set <model_id> | <provider_name> <model_name> - Set the chat model",
IsWrite: true,
Handler: func(cc CommandContext) (string, error) {
if len(cc.Args) < 2 {
return "Usage: /model set <provider_name> <model_name>", nil
if len(cc.Args) < 1 {
return "Usage: /model set <model_id> | <provider_name> <model_name>", nil
}
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
if h.settingsService == nil {
return "Settings service is not available.", nil
}
before, _ := h.getBotSettings(cc)
modelResp, err := h.findModelForSelection(cc, cc.Args)
if err != nil {
return "", err
}
@@ -51,18 +98,22 @@ func (h *Handler) buildModelGroup() *CommandGroup {
if err != nil {
return "", err
}
return fmt.Sprintf("Chat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
return formatChangedValue("Chat model", h.resolveModelName(cc, before.ChatModelID), h.resolveModelName(cc, modelResp.ID)), nil
},
})
g.Register(SubCommand{
Name: "set-heartbeat",
Usage: "set-heartbeat <provider_name> <model_name> - Set the heartbeat model",
Usage: "set-heartbeat <model_id> | <provider_name> <model_name> - Set the heartbeat model",
IsWrite: true,
Handler: func(cc CommandContext) (string, error) {
if len(cc.Args) < 2 {
return "Usage: /model set-heartbeat <provider_name> <model_name>", nil
if len(cc.Args) < 1 {
return "Usage: /model set-heartbeat <model_id> | <provider_name> <model_name>", nil
}
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
if h.settingsService == nil {
return "Settings service is not available.", nil
}
before, _ := h.getBotSettings(cc)
modelResp, err := h.findModelForSelection(cc, cc.Args)
if err != nil {
return "", err
}
@@ -72,7 +123,7 @@ func (h *Handler) buildModelGroup() *CommandGroup {
if err != nil {
return "", err
}
return fmt.Sprintf("Heartbeat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
return formatChangedValue("Heartbeat model", h.resolveModelName(cc, before.HeartbeatModelID), h.resolveModelName(cc, modelResp.ID)), nil
},
})
return g
@@ -105,3 +156,86 @@ func (h *Handler) findModelByProviderAndName(cc CommandContext, providerName, mo
}
return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName)
}
func (h *Handler) findModelForSelection(cc CommandContext, args []string) (models.GetResponse, error) {
if h.modelsService == nil {
return models.GetResponse{}, errors.New("model service is not available")
}
if len(args) == 0 {
return models.GetResponse{}, errors.New("model identifier is required")
}
if len(args) == 1 {
return h.findModelByIDOrName(cc, args[0])
}
return h.findModelByProviderAndName(cc, args[0], strings.Join(args[1:], " "))
}
func (h *Handler) findModelByIDOrName(cc CommandContext, target string) (models.GetResponse, error) {
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
if err != nil {
return models.GetResponse{}, err
}
target = strings.TrimSpace(target)
if target == "" {
return models.GetResponse{}, errors.New("model identifier is required")
}
for _, item := range items {
if strings.EqualFold(item.ModelID, target) {
return item, nil
}
}
matches := make([]models.GetResponse, 0, 4)
for _, item := range items {
if strings.EqualFold(item.Name, target) {
matches = append(matches, item)
}
}
switch len(matches) {
case 0:
return models.GetResponse{}, fmt.Errorf("model %q not found", target)
case 1:
return matches[0], nil
default:
choices := make([]string, 0, len(matches))
for _, item := range matches {
choices = append(choices, fmt.Sprintf("%s/%s", h.resolveProviderName(cc, item.ProviderID), item.ModelID))
}
return models.GetResponse{}, fmt.Errorf("model %q is ambiguous; use a model ID or provider-qualified name (%s)", target, strings.Join(choices, ", "))
}
}
func (h *Handler) filterModelsByProvider(cc CommandContext, items []models.GetResponse, providerName string) []models.GetResponse {
providerName = strings.TrimSpace(providerName)
if providerName == "" {
return items
}
filtered := make([]models.GetResponse, 0, len(items))
for _, item := range items {
if strings.EqualFold(h.resolveProviderName(cc, item.ProviderID), providerName) {
filtered = append(filtered, item)
}
}
return filtered
}
func modelMarkers(modelID string, settingsResp settings.Settings) []string {
var markers []string
if modelID == settingsResp.ChatModelID {
markers = append(markers, "chat")
}
if modelID == settingsResp.HeartbeatModelID {
markers = append(markers, "heartbeat")
}
return markers
}
func modelSortRank(model models.GetResponse, settingsResp settings.Settings) int {
switch len(modelMarkers(model.ID, settingsResp)) {
case 2:
return 0
case 1:
return 1
default:
return 2
}
}
+2
View File
@@ -13,6 +13,8 @@ func TestParse_Basic(t *testing.T) {
args []string
}{
{"/help", "help", "", nil},
{"/help model", "help", "model", nil},
{"/help model set", "help", "model", []string{"set"}},
{"/subagent list", "subagent", "list", nil},
{"/subagent get mybot", "subagent", "get", []string{"mybot"}},
{"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}},
+78 -12
View File
@@ -8,10 +8,18 @@ import (
// CommandContext carries execution context for a sub-command.
type CommandContext struct {
Ctx context.Context
BotID string
Role string // "owner", "admin", "member", or "" (guest)
Args []string
Ctx context.Context
BotID string
Role string // "owner", "admin", "member", or "" (guest)
Args []string
ChannelIdentityID string
UserID string
ChannelType string
ConversationType string
ConversationID string
ThreadID string
RouteID string
SessionID string
}
// SubCommand describes a single sub-command within a resource group.
@@ -50,15 +58,38 @@ func (g *CommandGroup) Usage() string {
fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description)
for _, name := range g.order {
sub := g.commands[name]
perm := ""
desc := subSummary(sub)
if sub.IsWrite {
perm = " [owner]"
desc += " [owner]"
}
fmt.Fprintf(&b, "- %s%s\n", sub.Usage, perm)
fmt.Fprintf(&b, "- %s\n", desc)
}
fmt.Fprintf(&b, "\nUse /help %s <action> for details.", g.Name)
return b.String()
}
func (g *CommandGroup) ActionHelp(action string) string {
sub, ok := g.commands[action]
if !ok {
return fmt.Sprintf("Unknown action %q for /%s.\n\n%s", action, g.Name, g.Usage())
}
usage, summary := splitUsage(sub.Usage)
var b strings.Builder
fmt.Fprintf(&b, "/%s %s\n", g.Name, sub.Name)
if summary != "" {
fmt.Fprintf(&b, "- Summary: %s\n", summary)
}
if usage == "" {
usage = sub.Name
}
fmt.Fprintf(&b, "- Usage: /%s %s\n", g.Name, usage)
if sub.IsWrite {
b.WriteString("- Access: owner only\n")
}
fmt.Fprintf(&b, "- Tip: use /help %s to view sibling actions.", g.Name)
return strings.TrimRight(b.String(), "\n")
}
// Registry holds all registered command groups.
type Registry struct {
groups map[string]*CommandGroup
@@ -83,12 +114,47 @@ func (r *Registry) GlobalHelp() string {
b.WriteString("/help - Show this help message\n")
b.WriteString("/new - Start a new conversation (resets session context)\n")
b.WriteString("/stop - Stop the current generation\n\n")
for i, name := range r.order {
if i > 0 {
b.WriteByte('\n')
}
for _, name := range r.order {
group := r.groups[name]
b.WriteString(group.Usage())
fmt.Fprintf(&b, "- /%s - %s\n", group.Name, group.Description)
}
b.WriteString("\nUse /help <group> to view actions, e.g. /help model")
return strings.TrimRight(b.String(), "\n")
}
func (r *Registry) GroupHelp(name string) string {
group, ok := r.groups[name]
if !ok {
return fmt.Sprintf("Unknown command group: /%s\n\n%s", name, r.GlobalHelp())
}
return group.Usage()
}
func (r *Registry) ActionHelp(groupName, action string) string {
group, ok := r.groups[groupName]
if !ok {
return fmt.Sprintf("Unknown command group: /%s\n\n%s", groupName, r.GlobalHelp())
}
return group.ActionHelp(action)
}
func splitUsage(usage string) (commandUsage string, summary string) {
usage = strings.TrimSpace(usage)
if usage == "" {
return "", ""
}
parts := strings.SplitN(usage, " - ", 2)
commandUsage = strings.TrimSpace(parts[0])
if len(parts) > 1 {
summary = strings.TrimSpace(parts[1])
}
return commandUsage, summary
}
func subSummary(sub SubCommand) string {
usage, summary := splitUsage(sub.Usage)
if summary == "" {
return usage
}
return fmt.Sprintf("%s - %s", sub.Name, summary)
}
+39 -6
View File
@@ -13,6 +13,9 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
Name: "list",
Usage: "list - List all search providers",
Handler: func(cc CommandContext) (string, error) {
if h.searchProvService == nil {
return "Search provider service is not available.", nil
}
items, err := h.searchProvService.List(cc.Ctx, "")
if err != nil {
return "", err
@@ -20,14 +23,40 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
if len(items) == 0 {
return "No search providers found.", nil
}
records := make([][]kv, 0, len(items))
settingsResp, _ := h.getBotSettings(cc)
currentRecords := make([][]kv, 0, 1)
otherRecords := make([][]kv, 0, len(items))
for _, item := range items {
records = append(records, []kv{
{"Name", item.Name},
label := item.Name
record := []kv{
{"Name", label},
{"Provider", item.Provider},
})
}
if item.ID == settingsResp.SearchProviderID {
label += " [current]"
record[0].value = label
currentRecords = append(currentRecords, record)
continue
}
otherRecords = append(otherRecords, record)
}
return formatItems(records), nil
currentRecords = append(currentRecords, otherRecords...)
records := currentRecords
return formatLimitedItems(records, defaultListLimit, "Use /search current to inspect the active provider."), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show the current search provider",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{{"Search Provider", h.resolveSearchProviderName(cc, settingsResp.SearchProviderID)}}), nil
},
})
g.Register(SubCommand{
@@ -38,7 +67,11 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
if len(cc.Args) < 1 {
return "Usage: /search set <name>", nil
}
if h.settingsService == nil {
return "Settings service is not available.", nil
}
name := cc.Args[0]
before, _ := h.getBotSettings(cc)
items, err := h.searchProvService.List(cc.Ctx, "")
if err != nil {
return "", err
@@ -51,7 +84,7 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
if err != nil {
return "", err
}
return fmt.Sprintf("Search provider set to %q.", item.Name), nil
return formatChangedValue("Search provider", h.resolveSearchProviderName(cc, before.SearchProviderID), item.Name), nil
}
}
return fmt.Sprintf("Search provider %q not found.", name), nil
+8
View File
@@ -1,6 +1,7 @@
package command
import (
"errors"
"fmt"
"strconv"
"strings"
@@ -106,6 +107,13 @@ func settingsUpdateUsage() string {
"- --heartbeat_model_id <id>"
}
func (h *Handler) getBotSettings(cc CommandContext) (settings.Settings, error) {
if h.settingsService == nil {
return settings.Settings{}, errors.New("settings service is not available")
}
return h.settingsService.GetBot(cc.Ctx, cc.BotID)
}
// resolveModelName resolves a model UUID to "model_name (provider_name)".
func (h *Handler) resolveModelName(cc CommandContext, modelID string) string {
if modelID == "" {
+94 -57
View File
@@ -6,7 +6,9 @@ import (
"strconv"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
func (h *Handler) buildStatusGroup() *CommandGroup {
@@ -16,78 +18,113 @@ func (h *Handler) buildStatusGroup() *CommandGroup {
Name: "show",
Usage: "show - Show current session status",
Handler: func(cc CommandContext) (string, error) {
if strings.TrimSpace(cc.SessionID) == "" {
return "No active session found for this conversation.", nil
}
return h.renderSessionStatus(cc, cc.SessionID, "current conversation")
},
})
g.Register(SubCommand{
Name: "latest",
Usage: "latest - Show the latest session status for this bot",
Handler: func(cc CommandContext) (string, error) {
if h.queries == nil {
return "Session info is not available.", nil
}
botUUID, err := parseBotUUID(cc.BotID)
if err != nil {
return "", err
}
sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "No active session found.", nil
return "No session found for this bot.", nil
}
return "", err
}
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, sessionID)
if err != nil {
return "", fmt.Errorf("count messages: %w", err)
}
var usedTokens int64
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, sessionID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return "", fmt.Errorf("get usage: %w", err)
}
if err == nil {
usedTokens = latestUsage
}
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, sessionID)
if err != nil {
return "", fmt.Errorf("get cache: %w", err)
}
var contextWindowStr string
if h.settingsService != nil {
s, sErr := h.settingsService.GetBot(cc.Ctx, cc.BotID)
if sErr == nil && s.ChatModelID != "" && h.modelsService != nil {
m, mErr := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
if mErr == nil && m.Config.ContextWindow != nil {
contextWindowStr = formatTokens(int64(*m.Config.ContextWindow))
}
}
}
var cacheHitRate float64
if cacheRow.TotalInputTokens > 0 {
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
}
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, sessionID)
var b strings.Builder
b.WriteString("Session Status:\n\n")
fmt.Fprintf(&b, "- Messages: %d\n", msgCount)
if contextWindowStr != "" {
fmt.Fprintf(&b, "- Context: %s / %s\n", formatTokens(usedTokens), contextWindowStr)
} else {
fmt.Fprintf(&b, "- Context: %s\n", formatTokens(usedTokens))
}
fmt.Fprintf(&b, "- Cache Hit Rate: %.1f%%\n", cacheHitRate)
fmt.Fprintf(&b, "- Cache Read: %s\n", formatTokens(cacheRow.CacheReadTokens))
fmt.Fprintf(&b, "- Cache Write: %s\n", formatTokens(cacheRow.CacheWriteTokens))
if len(skills) > 0 {
fmt.Fprintf(&b, "- Skills: %s\n", strings.Join(skills, ", "))
}
return strings.TrimRight(b.String(), "\n"), nil
return h.renderSessionStatus(cc, sessionID.String(), "latest bot session")
},
})
return g
}
func (h *Handler) renderSessionStatus(cc CommandContext, sessionID string, scope string) (string, error) {
if h.queries == nil {
return "Session info is not available.", nil
}
pgSessionID, err := parseCommandUUID(sessionID)
if err != nil {
return "", err
}
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, pgSessionID)
if err != nil {
return "", fmt.Errorf("count messages: %w", err)
}
var usedTokens int64
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, pgSessionID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return "", fmt.Errorf("get usage: %w", err)
}
if err == nil {
usedTokens = latestUsage
}
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, pgSessionID)
if err != nil {
return "", fmt.Errorf("get cache: %w", err)
}
var cacheHitRate float64
if cacheRow.TotalInputTokens > 0 {
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
}
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, pgSessionID)
contextUsage := formatTokens(usedTokens)
if contextWindow := h.resolveContextWindow(cc); contextWindow != "" {
contextUsage = contextUsage + " / " + contextWindow
}
pairs := []kv{
{"Scope", scope},
{"Session ID", sessionID},
{"Messages", strconv.FormatInt(msgCount, 10)},
{"Context", contextUsage},
{"Cache Hit Rate", fmt.Sprintf("%.1f%%", cacheHitRate)},
{"Cache Read", formatTokens(cacheRow.CacheReadTokens)},
{"Cache Write", formatTokens(cacheRow.CacheWriteTokens)},
}
if len(skills) > 0 {
pairs = append(pairs, kv{"Skills", strings.Join(skills, ", ")})
}
return formatKV(pairs), nil
}
func (h *Handler) resolveContextWindow(cc CommandContext) string {
if h.settingsService == nil || h.modelsService == nil {
return ""
}
s, err := h.settingsService.GetBot(cc.Ctx, cc.BotID)
if err != nil || s.ChatModelID == "" {
return ""
}
m, err := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
if err != nil || m.Config.ContextWindow == nil {
return ""
}
return formatTokens(int64(*m.Config.ContextWindow))
}
func parseCommandUUID(id string) (pgtype.UUID, error) {
parsed, err := uuid.Parse(strings.TrimSpace(id))
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid uuid: %w", err)
}
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
}
func formatTokens(n int64) string {
if n >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
+6
View File
@@ -18,6 +18,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
Name: "summary",
Usage: "summary - Token usage summary (last 7 days)",
Handler: func(cc CommandContext) (string, error) {
if h.queries == nil {
return "Usage info is not available.", nil
}
botUUID, err := parseBotUUID(cc.BotID)
if err != nil {
return "", err
@@ -89,6 +92,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
Name: "by-model",
Usage: "by-model - Token usage grouped by model",
Handler: func(cc CommandContext) (string, error) {
if h.queries == nil {
return "Usage info is not available.", nil
}
botUUID, err := parseBotUUID(cc.BotID)
if err != nil {
return "", err