feat(command): improve slash command UX (#361)

Make slash commands easier to navigate in chat by splitting help into levels, compacting list output, and surfacing current selections for model, search, memory, and browser settings. Also route /status to the active conversation session and add an access inspector so users can understand their current command and ACL context.
This commit is contained in:
Acbox
2026-04-13 12:37:12 +08:00
committed by GitHub
parent c9c221e35d
commit d46269de89
22 changed files with 1080 additions and 138 deletions
+9
View File
@@ -604,6 +604,7 @@ func provideChannelRouter(
emailOutboxService, emailOutboxService,
heartbeatService, heartbeatService,
queries, queries,
aclService,
&commandSkillLoaderAdapter{handler: containerdHandler}, &commandSkillLoaderAdapter{handler: containerdHandler},
&commandContainerFSAdapter{manager: manager}, &commandContainerFSAdapter{manager: manager},
)) ))
@@ -779,6 +780,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
} }
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
if err != nil {
return inbound.SessionResult{}, err
}
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
}
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) { func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType) sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
if err != nil { if err != nil {
+9
View File
@@ -492,6 +492,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc
emailOutboxService, emailOutboxService,
heartbeatService, heartbeatService,
queries, queries,
aclService,
&commandSkillLoaderAdapter{handler: containerdHandler}, &commandSkillLoaderAdapter{handler: containerdHandler},
&commandContainerFSAdapter{manager: manager}, &commandContainerFSAdapter{manager: manager},
)) ))
@@ -877,6 +878,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
} }
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
if err != nil {
return inbound.SessionResult{}, err
}
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
}
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) { func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType) sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
if err != nil { if err != nil {
+106 -1
View File
@@ -71,6 +71,7 @@ type ttsModelResolver interface {
// SessionEnsurer resolves or creates an active session for a route. // SessionEnsurer resolves or creates an active session for a route.
type SessionEnsurer interface { type SessionEnsurer interface {
EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error) EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error)
GetActiveSession(ctx context.Context, routeID string) (SessionResult, error)
// CreateNewSession always creates a fresh session and sets it as the // CreateNewSession always creates a fresh session and sets it as the
// active session for the given route, replacing any previous one. // active session for the given route, replacing any previous one.
// sessionType defaults to "chat" if empty. // sessionType defaults to "chat" if empty.
@@ -298,11 +299,23 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
if isStopCommand(cmdText) && isDirectedAtBot(msg) { if isStopCommand(cmdText) && isDirectedAtBot(msg) {
return p.handleStopCommand(ctx, cfg, msg, sender, identity) return p.handleStopCommand(ctx, cfg, msg, sender, identity)
} }
if isStatusCommand(cmdText) && isDirectedAtBot(msg) {
return p.handleStatusCommand(ctx, cfg, msg, sender, identity)
}
// Skip generic command handler for mode-prefix commands (/btw, /now, /next) // Skip generic command handler for mode-prefix commands (/btw, /now, /next)
// so they pass through to mode detection below. // so they pass through to mode detection below.
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) { if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText) reply, err := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
BotID: strings.TrimSpace(identity.BotID),
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
UserID: strings.TrimSpace(identity.UserID),
Text: cmdText,
ChannelType: msg.Channel.String(),
ConversationType: strings.TrimSpace(msg.Conversation.Type),
ConversationID: strings.TrimSpace(msg.Conversation.ID),
ThreadID: extractThreadID(msg),
})
if err != nil { if err != nil {
reply = "Error: " + err.Error() reply = "Error: " + err.Error()
} }
@@ -2489,6 +2502,18 @@ func isNewSessionCommand(cmdText string) bool {
return parsed.Resource == "new" return parsed.Resource == "new"
} }
func isStatusCommand(cmdText string) bool {
extracted := command.ExtractCommandText(cmdText)
if extracted == "" {
return false
}
parsed, err := command.Parse(extracted)
if err != nil {
return false
}
return parsed.Resource == "status"
}
// resolveNewSessionType determines the session type for /new command. // resolveNewSessionType determines the session type for /new command.
// /new chat → chat, /new discuss → discuss, /new (no arg) → default by context. // /new chat → chat, /new discuss → discuss, /new (no arg) → default by context.
// WebUI (local channel) always defaults to chat. // WebUI (local channel) always defaults to chat.
@@ -2611,3 +2636,83 @@ func (p *ChannelInboundProcessor) handleNewSessionCommand(
Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)}, Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)},
}) })
} }
func (p *ChannelInboundProcessor) handleStatusCommand(
ctx context.Context,
cfg channel.ChannelConfig,
msg channel.InboundMessage,
sender channel.StreamReplySender,
identity InboundIdentity,
) error {
target := strings.TrimSpace(msg.ReplyTarget)
if target == "" {
return errors.New("reply target missing for /status command")
}
if p.routeResolver == nil {
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: "Error: route resolver not configured."},
})
}
if p.commandHandler == nil {
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: "Error: command handler not configured."},
})
}
threadID := extractThreadID(msg)
routeMetadata := buildRouteMetadata(msg, identity)
p.enrichConversationAvatar(ctx, cfg, msg, routeMetadata)
resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{
BotID: identity.BotID,
Platform: msg.Channel.String(),
ConversationID: msg.Conversation.ID,
ThreadID: threadID,
ConversationType: msg.Conversation.Type,
ChannelIdentityID: identity.UserID,
ChannelConfigID: identity.ChannelConfigID,
ReplyTarget: target,
Metadata: routeMetadata,
})
if err != nil {
if p.logger != nil {
p.logger.Warn("resolve route for /status command failed", slog.Any("error", err))
}
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: "Error: failed to resolve conversation route."},
})
}
sessionID := ""
if p.sessionEnsurer != nil {
sess, sessErr := p.sessionEnsurer.GetActiveSession(ctx, resolved.RouteID)
if sessErr == nil {
sessionID = strings.TrimSpace(sess.ID)
} else if p.logger != nil {
p.logger.Debug("resolve active session for /status command failed", slog.Any("error", sessErr))
}
}
reply, execErr := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
BotID: strings.TrimSpace(identity.BotID),
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
UserID: strings.TrimSpace(identity.UserID),
Text: rawTextForCommand(msg, strings.TrimSpace(msg.Message.PlainText())),
ChannelType: msg.Channel.String(),
ConversationType: strings.TrimSpace(msg.Conversation.Type),
ConversationID: strings.TrimSpace(msg.Conversation.ID),
ThreadID: threadID,
RouteID: strings.TrimSpace(resolved.RouteID),
SessionID: sessionID,
})
if execErr != nil {
reply = "Error: " + execErr.Error()
}
return sender.Send(ctx, channel.OutboundMessage{
Target: target,
Message: channel.Message{Text: reply},
})
}
+135
View File
@@ -13,11 +13,15 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/acl" "github.com/memohai/memoh/internal/acl"
"github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/channel/identities"
"github.com/memohai/memoh/internal/channel/route" "github.com/memohai/memoh/internal/channel/route"
"github.com/memohai/memoh/internal/command"
"github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/conversation"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/media" "github.com/memohai/memoh/internal/media"
messagepkg "github.com/memohai/memoh/internal/message" messagepkg "github.com/memohai/memoh/internal/message"
"github.com/memohai/memoh/internal/schedule" "github.com/memohai/memoh/internal/schedule"
@@ -185,6 +189,71 @@ type fakeChatACL struct {
lastReq acl.EvaluateRequest lastReq acl.EvaluateRequest
} }
type fakeSessionEnsurer struct {
activeSession SessionResult
activeErr error
lastRouteID string
}
func (f *fakeSessionEnsurer) EnsureActiveSession(_ context.Context, _, routeID, _ string) (SessionResult, error) {
f.lastRouteID = routeID
if f.activeErr != nil {
return SessionResult{}, f.activeErr
}
return f.activeSession, nil
}
func (f *fakeSessionEnsurer) GetActiveSession(_ context.Context, routeID string) (SessionResult, error) {
f.lastRouteID = routeID
if f.activeErr != nil {
return SessionResult{}, f.activeErr
}
return f.activeSession, nil
}
func (f *fakeSessionEnsurer) CreateNewSession(_ context.Context, _, routeID, _, _ string) (SessionResult, error) {
f.lastRouteID = routeID
if f.activeErr != nil {
return SessionResult{}, f.activeErr
}
return f.activeSession, nil
}
type fakeCommandQueries struct {
messageCount int64
usage int64
cacheRow dbsqlc.GetSessionCacheStatsRow
skills []string
}
func (*fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
return pgtype.UUID{}, errors.New("unexpected latest session lookup")
}
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
return f.messageCount, nil
}
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
return f.usage, nil
}
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
return f.cacheRow, nil
}
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
return f.skills, nil
}
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
return nil, nil
}
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
return nil, nil
}
func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) { func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) {
f.calls++ f.calls++
f.lastReq = req f.lastReq = req
@@ -548,6 +617,72 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) {
} }
} }
func TestChannelInboundProcessorStatusUsesRouteSession(t *testing.T) {
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-status"}}
policySvc := &fakePolicyService{}
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-status", RouteID: "route-status"}}
gateway := &fakeChatGateway{}
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
processor.SetSessionEnsurer(&fakeSessionEnsurer{
activeSession: SessionResult{ID: "11111111-1111-1111-1111-111111111111", Type: "chat"},
})
processor.SetCommandHandler(command.NewHandler(
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
&fakeCommandQueries{
messageCount: 9,
usage: 512,
cacheRow: dbsqlc.GetSessionCacheStatsRow{
CacheReadTokens: 64,
CacheWriteTokens: 32,
TotalInputTokens: 512,
},
skills: []string{"search"},
},
nil,
nil,
nil,
))
sender := &fakeReplySender{}
cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("discord")}
msg := channel.InboundMessage{
BotID: "bot-1",
Channel: channel.ChannelType("discord"),
Message: channel.Message{Text: "/status"},
ReplyTarget: "discord:status",
Sender: channel.Identity{SubjectID: "user-1"},
Conversation: channel.Conversation{
ID: "conv-status",
Type: channel.ConversationTypePrivate,
},
}
if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one status reply, got %d", len(sender.sent))
}
if !strings.Contains(sender.sent[0].Message.Text, "- Scope: current conversation") {
t.Fatalf("expected current conversation scope, got %q", sender.sent[0].Message.Text)
}
if !strings.Contains(sender.sent[0].Message.Text, "- Session ID: 11111111-1111-1111-1111-111111111111") {
t.Fatalf("expected active route session in reply, got %q", sender.sent[0].Message.Text)
}
}
func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) { func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) {
t.Parallel() t.Parallel()
+78
View File
@@ -0,0 +1,78 @@
package command
import (
"fmt"
"strings"
"github.com/memohai/memoh/internal/acl"
)
func (h *Handler) buildAccessGroup() *CommandGroup {
g := newCommandGroup("access", "Inspect identity and permission context")
g.DefaultAction = "show"
g.Register(SubCommand{
Name: "show",
Usage: "show - Show current identity, write access, and chat ACL context",
Handler: func(cc CommandContext) (string, error) {
writeAccess := "no"
if cc.Role == "owner" {
writeAccess = "yes"
}
pairs := []kv{
{"Channel Identity", fallbackValue(cc.ChannelIdentityID)},
{"Linked User", fallbackValue(cc.UserID)},
{"Bot Role", fallbackValue(cc.Role)},
{"Write Commands", writeAccess},
{"Channel", fallbackValue(cc.ChannelType)},
{"Conversation Type", fallbackValue(cc.ConversationType)},
{"Conversation ID", fallbackValue(cc.ConversationID)},
{"Thread ID", fallbackValue(cc.ThreadID)},
}
if strings.TrimSpace(cc.RouteID) != "" {
pairs = append(pairs, kv{"Route ID", cc.RouteID})
}
if strings.TrimSpace(cc.SessionID) != "" {
pairs = append(pairs, kv{"Session ID", cc.SessionID})
}
aclStatus := "unavailable"
if h.aclEvaluator != nil && strings.TrimSpace(cc.ChannelType) != "" {
allowed, err := h.aclEvaluator.Evaluate(cc.Ctx, acl.EvaluateRequest{
BotID: cc.BotID,
ChannelIdentityID: cc.ChannelIdentityID,
ChannelType: cc.ChannelType,
SourceScope: acl.SourceScope{
ConversationType: cc.ConversationType,
ConversationID: cc.ConversationID,
ThreadID: cc.ThreadID,
},
})
switch {
case err != nil:
aclStatus = "error: " + err.Error()
case allowed:
aclStatus = "allow"
default:
aclStatus = "deny"
}
}
pairs = append(pairs, kv{"Chat ACL", aclStatus})
return formatKV(pairs), nil
},
})
return g
}
func fallbackValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return "(none)"
}
return value
}
func formatChangedValue(label, before, after string) string {
return fmt.Sprintf("%s: %s -> %s", label, fallbackValue(before), fallbackValue(after))
}
+37 -6
View File
@@ -13,6 +13,9 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
Name: "list", Name: "list",
Usage: "list - List all browser contexts", Usage: "list - List all browser contexts",
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if h.browserCtxService == nil {
return "Browser context service is not available.", nil
}
items, err := h.browserCtxService.List(cc.Ctx) items, err := h.browserCtxService.List(cc.Ctx)
if err != nil { if err != nil {
return "", err return "", err
@@ -20,13 +23,37 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
if len(items) == 0 { if len(items) == 0 {
return "No browser contexts found.", nil return "No browser contexts found.", nil
} }
records := make([][]kv, 0, len(items)) settingsResp, _ := h.getBotSettings(cc)
currentRecords := make([][]kv, 0, 1)
otherRecords := make([][]kv, 0, len(items))
for _, item := range items { for _, item := range items {
records = append(records, []kv{ label := item.Name
{"Name", item.Name}, record := []kv{{"Name", label}}
}) if item.ID == settingsResp.BrowserContextID {
label += " [current]"
record[0].value = label
currentRecords = append(currentRecords, record)
continue
}
otherRecords = append(otherRecords, record)
} }
return formatItems(records), nil currentRecords = append(currentRecords, otherRecords...)
records := currentRecords
return formatLimitedItems(records, defaultListLimit, "Use /browser current to inspect the active context."), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show the current browser context",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{{"Browser Context", h.resolveBrowserContextName(cc, settingsResp.BrowserContextID)}}), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
@@ -37,7 +64,11 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
if len(cc.Args) < 1 { if len(cc.Args) < 1 {
return "Usage: /browser set <name>", nil return "Usage: /browser set <name>", nil
} }
if h.settingsService == nil {
return "Settings service is not available.", nil
}
name := cc.Args[0] name := cc.Args[0]
before, _ := h.getBotSettings(cc)
items, err := h.browserCtxService.List(cc.Ctx) items, err := h.browserCtxService.List(cc.Ctx)
if err != nil { if err != nil {
return "", err return "", err
@@ -50,7 +81,7 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("Browser context set to %q.", item.Name), nil return formatChangedValue("Browser context", h.resolveBrowserContextName(cc, before.BrowserContextID), item.Name), nil
} }
} }
return fmt.Sprintf("Browser context %q not found.", name), nil return fmt.Sprintf("Browser context %q not found.", name), nil
+1
View File
@@ -16,5 +16,6 @@ func (h *Handler) buildRegistry() *Registry {
r.RegisterGroup(h.buildSkillGroup()) r.RegisterGroup(h.buildSkillGroup())
r.RegisterGroup(h.buildFSGroup()) r.RegisterGroup(h.buildFSGroup())
r.RegisterGroup(h.buildStatusGroup()) r.RegisterGroup(h.buildStatusGroup())
r.RegisterGroup(h.buildAccessGroup())
return r return r
} }
+3 -3
View File
@@ -24,7 +24,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
{"Provider", item.Provider}, {"Provider", item.Provider},
}) })
} }
return formatItems(records), nil return formatLimitedItems(records, defaultListLimit, "Use /email bindings to inspect bot bindings."), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
@@ -46,7 +46,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
{"Permissions", perms}, {"Permissions", perms},
}) })
} }
return formatItems(records), nil return formatLimitedItems(records, defaultListLimit, "Use /email outbox to inspect recent sends."), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
@@ -70,7 +70,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
{"Sent", item.SentAt.Format("01-02 15:04")}, {"Sent", item.SentAt.Format("01-02 15:04")},
}) })
} }
return formatItems(records), nil return formatLimitedItems(records, 10, "Use the Web UI for older outbox entries."), nil
}, },
}) })
return g return g
+32 -14
View File
@@ -6,19 +6,10 @@ import (
"unicode/utf8" "unicode/utf8"
) )
const defaultListLimit = 12
// formatItems renders a list of records as a Markdown-style list. // formatItems renders a list of records as a Markdown-style list.
// Each record is a slice of kv pairs; the first pair's value is used as the // Each record is rendered on a single line so long lists stay readable in IM.
// bullet title, and subsequent pairs are indented beneath it.
//
// Example output:
//
// - mybot
// Description: A helpful assistant
// ID: abc123
//
// - another
// Description: Something else
// ID: def456
func formatItems(items [][]kv) string { func formatItems(items [][]kv) string {
if len(items) == 0 { if len(items) == 0 {
return "" return ""
@@ -31,14 +22,41 @@ func formatItems(items [][]kv) string {
if i > 0 { if i > 0 {
b.WriteByte('\n') b.WriteByte('\n')
} }
fmt.Fprintf(&b, "- %s\n", record[0].value) fmt.Fprintf(&b, "- %s", record[0].value)
extras := make([]string, 0, len(record)-1)
for _, pair := range record[1:] { for _, pair := range record[1:] {
fmt.Fprintf(&b, " %s: %s\n", pair.key, pair.value) if strings.TrimSpace(pair.value) == "" {
continue
}
extras = append(extras, fmt.Sprintf("%s: %s", pair.key, pair.value))
}
if len(extras) > 0 {
fmt.Fprintf(&b, " | %s", strings.Join(extras, " | "))
} }
} }
return b.String() return b.String()
} }
func formatLimitedItems(items [][]kv, limit int, hint string) string {
if len(items) == 0 {
return ""
}
if limit <= 0 {
limit = defaultListLimit
}
total := len(items)
if total <= limit {
return formatItems(items)
}
shown := items[:limit]
result := formatItems(shown)
suffix := fmt.Sprintf("Showing %d of %d items.", len(shown), total)
if strings.TrimSpace(hint) != "" {
suffix += " " + strings.TrimSpace(hint)
}
return result + "\n\n" + suffix
}
// formatKV renders key-value pairs as a simple Markdown list. // formatKV renders key-value pairs as a simple Markdown list.
// //
// Example output: // Example output:
+57 -13
View File
@@ -4,10 +4,10 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"strings"
"github.com/memohai/memoh/internal/bots" "github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/browsercontexts" "github.com/memohai/memoh/internal/browsercontexts"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
emailpkg "github.com/memohai/memoh/internal/email" emailpkg "github.com/memohai/memoh/internal/email"
"github.com/memohai/memoh/internal/heartbeat" "github.com/memohai/memoh/internal/heartbeat"
"github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/mcp"
@@ -56,13 +56,28 @@ type Handler struct {
emailService *emailpkg.Service emailService *emailpkg.Service
emailOutboxService *emailpkg.OutboxService emailOutboxService *emailpkg.OutboxService
heartbeatService *heartbeat.Service heartbeatService *heartbeat.Service
queries *dbsqlc.Queries queries CommandQueries
aclEvaluator AccessEvaluator
skillLoader SkillLoader skillLoader SkillLoader
containerFS ContainerFS containerFS ContainerFS
logger *slog.Logger logger *slog.Logger
} }
// ExecuteInput carries the caller identity and channel context for command execution.
type ExecuteInput struct {
BotID string
ChannelIdentityID string
UserID string
Text string
ChannelType string
ConversationType string
ConversationID string
ThreadID string
RouteID string
SessionID string
}
// NewHandler creates a Handler with all required services. // NewHandler creates a Handler with all required services.
func NewHandler( func NewHandler(
log *slog.Logger, log *slog.Logger,
@@ -78,7 +93,8 @@ func NewHandler(
emailService *emailpkg.Service, emailService *emailpkg.Service,
emailOutboxService *emailpkg.OutboxService, emailOutboxService *emailpkg.OutboxService,
heartbeatService *heartbeat.Service, heartbeatService *heartbeat.Service,
queries *dbsqlc.Queries, queries CommandQueries,
aclEvaluator AccessEvaluator,
skillLoader SkillLoader, skillLoader SkillLoader,
containerFS ContainerFS, containerFS ContainerFS,
) *Handler { ) *Handler {
@@ -99,6 +115,7 @@ func NewHandler(
emailOutboxService: emailOutboxService, emailOutboxService: emailOutboxService,
heartbeatService: heartbeatService, heartbeatService: heartbeatService,
queries: queries, queries: queries,
aclEvaluator: aclEvaluator,
skillLoader: skillLoader, skillLoader: skillLoader,
containerFS: containerFS, containerFS: containerFS,
logger: log.With(slog.String("component", "command")), logger: log.With(slog.String("component", "command")),
@@ -140,7 +157,16 @@ func (h *Handler) IsCommand(text string) bool {
// Execute parses and runs a slash command, returning the text reply. // Execute parses and runs a slash command, returning the text reply.
func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) { func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) {
cmdText := ExtractCommandText(text) return h.ExecuteWithInput(ctx, ExecuteInput{
BotID: botID,
ChannelIdentityID: channelIdentityID,
Text: text,
})
}
// ExecuteWithInput parses and runs a slash command with channel/session context.
func (h *Handler) ExecuteWithInput(ctx context.Context, input ExecuteInput) (string, error) {
cmdText := ExtractCommandText(input.Text)
if cmdText == "" { if cmdText == "" {
return h.registry.GlobalHelp(), nil return h.registry.GlobalHelp(), nil
} }
@@ -151,12 +177,16 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
// Resolve the user's role in this bot. // Resolve the user's role in this bot.
role := "" role := ""
if h.roleResolver != nil && channelIdentityID != "" { roleIdentityID := input.ChannelIdentityID
r, err := h.roleResolver.GetMemberRole(ctx, botID, channelIdentityID) if strings.TrimSpace(input.UserID) != "" {
roleIdentityID = strings.TrimSpace(input.UserID)
}
if h.roleResolver != nil && roleIdentityID != "" {
r, err := h.roleResolver.GetMemberRole(ctx, input.BotID, roleIdentityID)
if err != nil { if err != nil {
h.logger.Warn("failed to resolve member role", h.logger.Warn("failed to resolve member role",
slog.String("bot_id", botID), slog.String("bot_id", input.BotID),
slog.String("channel_identity_id", channelIdentityID), slog.String("role_identity_id", roleIdentityID),
slog.Any("error", err), slog.Any("error", err),
) )
} else { } else {
@@ -165,15 +195,29 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
} }
cc := CommandContext{ cc := CommandContext{
Ctx: ctx, Ctx: ctx,
BotID: botID, BotID: input.BotID,
Role: role, Role: role,
Args: parsed.Args, Args: parsed.Args,
ChannelIdentityID: strings.TrimSpace(input.ChannelIdentityID),
UserID: strings.TrimSpace(input.UserID),
ChannelType: strings.TrimSpace(input.ChannelType),
ConversationType: strings.TrimSpace(input.ConversationType),
ConversationID: strings.TrimSpace(input.ConversationID),
ThreadID: strings.TrimSpace(input.ThreadID),
RouteID: strings.TrimSpace(input.RouteID),
SessionID: strings.TrimSpace(input.SessionID),
} }
// /help // /help
if parsed.Resource == "help" { if parsed.Resource == "help" {
return h.registry.GlobalHelp(), nil if parsed.Action == "" {
return h.registry.GlobalHelp(), nil
}
if len(parsed.Args) == 0 {
return h.registry.GroupHelp(parsed.Action), nil
}
return h.registry.ActionHelp(parsed.Action, parsed.Args[0]), nil
} }
// Top-level commands (e.g. /new) are handled by the channel inbound // Top-level commands (e.g. /new) are handled by the channel inbound
+173 -4
View File
@@ -5,6 +5,10 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/mcp"
"github.com/memohai/memoh/internal/schedule" "github.com/memohai/memoh/internal/schedule"
"github.com/memohai/memoh/internal/settings" "github.com/memohai/memoh/internal/settings"
@@ -25,9 +29,58 @@ type fakeScheduleService struct {
items []schedule.Schedule items []schedule.Schedule
} }
type fakeCommandQueries struct {
latestSessionID pgtype.UUID
latestSessionErr error
messageCount int64
latestUsage int64
latestUsageErr error
cacheRow dbsqlc.GetSessionCacheStatsRow
cacheErr error
skills []string
}
func (f *fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
return f.latestSessionID, f.latestSessionErr
}
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
return f.messageCount, nil
}
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
if f.latestUsageErr != nil {
return 0, f.latestUsageErr
}
return f.latestUsage, nil
}
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
if f.cacheErr != nil {
return dbsqlc.GetSessionCacheStatsRow{}, f.cacheErr
}
return f.cacheRow, nil
}
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
return f.skills, nil
}
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
return nil, nil
}
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
return nil, nil
}
// newTestHandler creates a Handler with nil services for use in tests. // newTestHandler creates a Handler with nil services for use in tests.
func newTestHandler(roleResolver MemberRoleResolver) *Handler { func newTestHandler(roleResolver MemberRoleResolver) *Handler {
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
}
func newTestHandlerWithQueries(roleResolver MemberRoleResolver, queries CommandQueries) *Handler {
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, queries, nil, nil, nil)
} }
// --- tests --- // --- tests ---
@@ -73,6 +126,42 @@ func TestExecute_Help(t *testing.T) {
if !strings.Contains(result, "Available commands") { if !strings.Contains(result, "Available commands") {
t.Errorf("expected help text, got: %s", result) t.Errorf("expected help text, got: %s", result)
} }
if strings.Contains(result, "set-heartbeat") {
t.Errorf("top-level help should not expand nested actions, got: %s", result)
}
if !strings.Contains(result, "- /model - Manage bot models") {
t.Errorf("expected top-level model entry, got: %s", result)
}
}
func TestExecute_HelpGroup(t *testing.T) {
t.Parallel()
h := newTestHandler(&fakeRoleResolver{role: "owner"})
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "/model - Manage bot models") {
t.Errorf("expected group help, got: %s", result)
}
if !strings.Contains(result, "- set - Set the chat model [owner]") {
t.Errorf("expected compact action summary, got: %s", result)
}
}
func TestExecute_HelpAction(t *testing.T) {
t.Parallel()
h := newTestHandler(&fakeRoleResolver{role: "owner"})
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model set")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "Usage: /model set <model_id> | <provider_name> <model_name>") {
t.Errorf("expected action usage, got: %s", result)
}
if !strings.Contains(result, "Access: owner only") {
t.Errorf("expected owner hint, got: %s", result)
}
} }
func TestExecute_UnknownCommand(t *testing.T) { func TestExecute_UnknownCommand(t *testing.T) {
@@ -207,8 +296,8 @@ func TestFormatItems(t *testing.T) {
if !strings.Contains(result, "- foo") { if !strings.Contains(result, "- foo") {
t.Errorf("expected '- foo' bullet, got: %s", result) t.Errorf("expected '- foo' bullet, got: %s", result)
} }
if !strings.Contains(result, " Type: bar") { if !strings.Contains(result, "- foo | Type: bar") {
t.Errorf("expected indented 'Type: bar', got: %s", result) t.Errorf("expected compact line entry, got: %s", result)
} }
if !strings.Contains(result, "- longname") { if !strings.Contains(result, "- longname") {
t.Errorf("expected '- longname' bullet, got: %s", result) t.Errorf("expected '- longname' bullet, got: %s", result)
@@ -255,7 +344,7 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
for _, group := range []string{ for _, group := range []string{
"schedule", "mcp", "settings", "schedule", "mcp", "settings",
"model", "memory", "search", "browser", "usage", "model", "memory", "search", "browser", "usage",
"email", "heartbeat", "skill", "fs", "email", "heartbeat", "skill", "fs", "access",
} { } {
if !strings.Contains(help, "/"+group) { if !strings.Contains(help, "/"+group) {
t.Errorf("missing /%s in global help", group) t.Errorf("missing /%s in global help", group)
@@ -263,6 +352,86 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
} }
} }
func TestExecuteWithInput_Access(t *testing.T) {
t.Parallel()
h := newTestHandler(&fakeRoleResolver{role: "owner"})
result, err := h.ExecuteWithInput(context.Background(), ExecuteInput{
BotID: "bot-1",
ChannelIdentityID: "channel-id-1",
UserID: "user-id-1",
Text: "/access",
ChannelType: "discord",
ConversationType: "thread",
ConversationID: "conv-1",
ThreadID: "thread-1",
RouteID: "route-1",
SessionID: "session-1",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "- Channel Identity: channel-id-1") {
t.Errorf("expected channel identity in access output, got: %s", result)
}
if !strings.Contains(result, "- Write Commands: yes") {
t.Errorf("expected write access in access output, got: %s", result)
}
}
func TestExecute_StatusLatest(t *testing.T) {
t.Parallel()
sessionUUID := pgtype.UUID{}
copy(sessionUUID.Bytes[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
sessionUUID.Valid = true
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
latestSessionID: sessionUUID,
messageCount: 42,
latestUsage: 1200,
cacheRow: dbsqlc.GetSessionCacheStatsRow{
CacheReadTokens: 300,
CacheWriteTokens: 150,
TotalInputTokens: 1200,
},
skills: []string{"search", "browser"},
})
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "- Scope: latest bot session") {
t.Errorf("expected latest scope, got: %s", result)
}
if !strings.Contains(result, "- Messages: 42") {
t.Errorf("expected message count, got: %s", result)
}
}
func TestExecute_StatusLatestNoRows(t *testing.T) {
t.Parallel()
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
latestSessionErr: pgx.ErrNoRows,
})
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "No session found for this bot.") {
t.Errorf("expected no session message, got: %s", result)
}
}
func TestExecute_StatusShowWithoutSession(t *testing.T) {
t.Parallel()
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{})
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/status")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "No active session found for this conversation.") {
t.Errorf("expected route-aware no session message, got: %s", result)
}
}
// Verify write commands are tagged with [owner] in usage. // Verify write commands are tagged with [owner] in usage.
func TestUsage_OwnerTag(t *testing.T) { func TestUsage_OwnerTag(t *testing.T) {
t.Parallel() t.Parallel()
+1 -1
View File
@@ -40,7 +40,7 @@ func (h *Handler) buildHeartbeatGroup() *CommandGroup {
} }
records = append(records, rec) records = append(records, rec)
} }
return formatItems(records), nil return formatLimitedItems(records, 10, "Use the Web UI for older heartbeat logs."), nil
}, },
}) })
return g return g
+25 -1
View File
@@ -1,6 +1,13 @@
package command package command
import "context" import (
"context"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/acl"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
)
// Skill represents a single skill loaded from a bot's container. // Skill represents a single skill loaded from a bot's container.
type Skill struct { type Skill struct {
@@ -25,3 +32,20 @@ type ContainerFS interface {
ListDir(ctx context.Context, botID, path string) ([]FSEntry, error) ListDir(ctx context.Context, botID, path string) ([]FSEntry, error)
ReadFile(ctx context.Context, botID, path string) (string, error) ReadFile(ctx context.Context, botID, path string) (string, error)
} }
// CommandQueries captures the sqlc methods used by slash commands.
// *dbsqlc.Queries satisfies this interface directly.
type CommandQueries interface {
GetLatestSessionIDByBot(ctx context.Context, botID pgtype.UUID) (pgtype.UUID, error)
CountMessagesBySession(ctx context.Context, sessionID pgtype.UUID) (int64, error)
GetLatestAssistantUsage(ctx context.Context, sessionID pgtype.UUID) (int64, error)
GetSessionCacheStats(ctx context.Context, sessionID pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error)
GetSessionUsedSkills(ctx context.Context, sessionID pgtype.UUID) ([]string, error)
GetTokenUsageByDayAndType(ctx context.Context, arg dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error)
GetTokenUsageByModel(ctx context.Context, arg dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error)
}
// AccessEvaluator checks whether the current channel context may trigger chat.
type AccessEvaluator interface {
Evaluate(ctx context.Context, req acl.EvaluateRequest) (bool, error)
}
+1 -1
View File
@@ -27,7 +27,7 @@ func (h *Handler) buildMCPGroup() *CommandGroup {
{"Status", item.Status}, {"Status", item.Status},
}) })
} }
return formatItems(records), nil return formatLimitedItems(records, defaultListLimit, "Use /mcp get <name> for full details."), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
+39 -6
View File
@@ -13,6 +13,9 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
Name: "list", Name: "list",
Usage: "list - List all memory providers", Usage: "list - List all memory providers",
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if h.memProvService == nil {
return "Memory provider service is not available.", nil
}
items, err := h.memProvService.List(cc.Ctx) items, err := h.memProvService.List(cc.Ctx)
if err != nil { if err != nil {
return "", err return "", err
@@ -20,18 +23,44 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
if len(items) == 0 { if len(items) == 0 {
return "No memory providers found.", nil return "No memory providers found.", nil
} }
records := make([][]kv, 0, len(items)) settingsResp, _ := h.getBotSettings(cc)
currentRecords := make([][]kv, 0, 1)
otherRecords := make([][]kv, 0, len(items))
for _, item := range items { for _, item := range items {
def := "" def := ""
if item.IsDefault { if item.IsDefault {
def = " (default)" def = " (default)"
} }
records = append(records, []kv{ label := item.Name + def
{"Name", item.Name + def}, record := []kv{
{"Name", label},
{"Provider", item.Provider}, {"Provider", item.Provider},
}) }
if item.ID == settingsResp.MemoryProviderID {
label += " [current]"
record[0].value = label
currentRecords = append(currentRecords, record)
continue
}
otherRecords = append(otherRecords, record)
} }
return formatItems(records), nil currentRecords = append(currentRecords, otherRecords...)
records := currentRecords
return formatLimitedItems(records, defaultListLimit, "Use /memory current to inspect the active provider."), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show the current memory provider",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{{"Memory Provider", h.resolveMemoryProviderName(cc, settingsResp.MemoryProviderID)}}), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
@@ -42,7 +71,11 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
if len(cc.Args) < 1 { if len(cc.Args) < 1 {
return "Usage: /memory set <name>", nil return "Usage: /memory set <name>", nil
} }
if h.settingsService == nil {
return "Settings service is not available.", nil
}
name := cc.Args[0] name := cc.Args[0]
before, _ := h.getBotSettings(cc)
items, err := h.memProvService.List(cc.Ctx) items, err := h.memProvService.List(cc.Ctx)
if err != nil { if err != nil {
return "", err return "", err
@@ -55,7 +88,7 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("Memory provider set to %q.", item.Name), nil return formatChangedValue("Memory provider", h.resolveMemoryProviderName(cc, before.MemoryProviderID), item.Name), nil
} }
} }
return fmt.Sprintf("Memory provider %q not found.", name), nil return fmt.Sprintf("Memory provider %q not found.", name), nil
+147 -13
View File
@@ -1,7 +1,9 @@
package command package command
import ( import (
"errors"
"fmt" "fmt"
"sort"
"strings" "strings"
"github.com/memohai/memoh/internal/models" "github.com/memohai/memoh/internal/models"
@@ -12,36 +14,81 @@ func (h *Handler) buildModelGroup() *CommandGroup {
g := newCommandGroup("model", "Manage bot models") g := newCommandGroup("model", "Manage bot models")
g.Register(SubCommand{ g.Register(SubCommand{
Name: "list", Name: "list",
Usage: "list - List all available chat models", Usage: "list [provider_name] - List available chat models",
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if h.modelsService == nil {
return "Model service is not available.", nil
}
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat) items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
if err != nil { if err != nil {
return "", err return "", err
} }
filterProvider := ""
if len(cc.Args) > 0 {
filterProvider = strings.TrimSpace(strings.Join(cc.Args, " "))
}
items = h.filterModelsByProvider(cc, items, filterProvider)
if len(items) == 0 { if len(items) == 0 {
if filterProvider != "" {
return fmt.Sprintf("No chat models found for provider %q.", filterProvider), nil
}
return "No chat models found.", nil return "No chat models found.", nil
} }
settingsResp, _ := h.getBotSettings(cc)
sort.SliceStable(items, func(i, j int) bool {
return modelSortRank(items[i], settingsResp) < modelSortRank(items[j], settingsResp)
})
records := make([][]kv, 0, len(items)) records := make([][]kv, 0, len(items))
for _, item := range items { for _, item := range items {
provName := h.resolveProviderName(cc, item.ProviderID) provName := h.resolveProviderName(cc, item.ProviderID)
label := item.Name
markers := modelMarkers(item.ID, settingsResp)
if len(markers) > 0 {
label += " [" + strings.Join(markers, ", ") + "]"
}
records = append(records, []kv{ records = append(records, []kv{
{"Model", item.Name}, {"Model", label},
{"Provider", provName}, {"Provider", provName},
{"Model ID", item.ModelID}, {"Model ID", item.ModelID},
}) })
} }
return formatItems(records), nil hint := "Use /model current to inspect active selections."
if filterProvider == "" {
hint = "Use /model list <provider_name> to narrow results."
}
return formatLimitedItems(records, defaultListLimit, hint), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show current chat and heartbeat models",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{
{"Chat Model", h.resolveModelName(cc, settingsResp.ChatModelID)},
{"Heartbeat Model", h.resolveModelName(cc, settingsResp.HeartbeatModelID)},
}), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
Name: "set", Name: "set",
Usage: "set <provider_name> <model_name> - Set the chat model", Usage: "set <model_id> | <provider_name> <model_name> - Set the chat model",
IsWrite: true, IsWrite: true,
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if len(cc.Args) < 2 { if len(cc.Args) < 1 {
return "Usage: /model set <provider_name> <model_name>", nil return "Usage: /model set <model_id> | <provider_name> <model_name>", nil
} }
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1]) if h.settingsService == nil {
return "Settings service is not available.", nil
}
before, _ := h.getBotSettings(cc)
modelResp, err := h.findModelForSelection(cc, cc.Args)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -51,18 +98,22 @@ func (h *Handler) buildModelGroup() *CommandGroup {
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("Chat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil return formatChangedValue("Chat model", h.resolveModelName(cc, before.ChatModelID), h.resolveModelName(cc, modelResp.ID)), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
Name: "set-heartbeat", Name: "set-heartbeat",
Usage: "set-heartbeat <provider_name> <model_name> - Set the heartbeat model", Usage: "set-heartbeat <model_id> | <provider_name> <model_name> - Set the heartbeat model",
IsWrite: true, IsWrite: true,
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if len(cc.Args) < 2 { if len(cc.Args) < 1 {
return "Usage: /model set-heartbeat <provider_name> <model_name>", nil return "Usage: /model set-heartbeat <model_id> | <provider_name> <model_name>", nil
} }
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1]) if h.settingsService == nil {
return "Settings service is not available.", nil
}
before, _ := h.getBotSettings(cc)
modelResp, err := h.findModelForSelection(cc, cc.Args)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -72,7 +123,7 @@ func (h *Handler) buildModelGroup() *CommandGroup {
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("Heartbeat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil return formatChangedValue("Heartbeat model", h.resolveModelName(cc, before.HeartbeatModelID), h.resolveModelName(cc, modelResp.ID)), nil
}, },
}) })
return g return g
@@ -105,3 +156,86 @@ func (h *Handler) findModelByProviderAndName(cc CommandContext, providerName, mo
} }
return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName) return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName)
} }
func (h *Handler) findModelForSelection(cc CommandContext, args []string) (models.GetResponse, error) {
if h.modelsService == nil {
return models.GetResponse{}, errors.New("model service is not available")
}
if len(args) == 0 {
return models.GetResponse{}, errors.New("model identifier is required")
}
if len(args) == 1 {
return h.findModelByIDOrName(cc, args[0])
}
return h.findModelByProviderAndName(cc, args[0], strings.Join(args[1:], " "))
}
func (h *Handler) findModelByIDOrName(cc CommandContext, target string) (models.GetResponse, error) {
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
if err != nil {
return models.GetResponse{}, err
}
target = strings.TrimSpace(target)
if target == "" {
return models.GetResponse{}, errors.New("model identifier is required")
}
for _, item := range items {
if strings.EqualFold(item.ModelID, target) {
return item, nil
}
}
matches := make([]models.GetResponse, 0, 4)
for _, item := range items {
if strings.EqualFold(item.Name, target) {
matches = append(matches, item)
}
}
switch len(matches) {
case 0:
return models.GetResponse{}, fmt.Errorf("model %q not found", target)
case 1:
return matches[0], nil
default:
choices := make([]string, 0, len(matches))
for _, item := range matches {
choices = append(choices, fmt.Sprintf("%s/%s", h.resolveProviderName(cc, item.ProviderID), item.ModelID))
}
return models.GetResponse{}, fmt.Errorf("model %q is ambiguous; use a model ID or provider-qualified name (%s)", target, strings.Join(choices, ", "))
}
}
func (h *Handler) filterModelsByProvider(cc CommandContext, items []models.GetResponse, providerName string) []models.GetResponse {
providerName = strings.TrimSpace(providerName)
if providerName == "" {
return items
}
filtered := make([]models.GetResponse, 0, len(items))
for _, item := range items {
if strings.EqualFold(h.resolveProviderName(cc, item.ProviderID), providerName) {
filtered = append(filtered, item)
}
}
return filtered
}
func modelMarkers(modelID string, settingsResp settings.Settings) []string {
var markers []string
if modelID == settingsResp.ChatModelID {
markers = append(markers, "chat")
}
if modelID == settingsResp.HeartbeatModelID {
markers = append(markers, "heartbeat")
}
return markers
}
func modelSortRank(model models.GetResponse, settingsResp settings.Settings) int {
switch len(modelMarkers(model.ID, settingsResp)) {
case 2:
return 0
case 1:
return 1
default:
return 2
}
}
+2
View File
@@ -13,6 +13,8 @@ func TestParse_Basic(t *testing.T) {
args []string args []string
}{ }{
{"/help", "help", "", nil}, {"/help", "help", "", nil},
{"/help model", "help", "model", nil},
{"/help model set", "help", "model", []string{"set"}},
{"/subagent list", "subagent", "list", nil}, {"/subagent list", "subagent", "list", nil},
{"/subagent get mybot", "subagent", "get", []string{"mybot"}}, {"/subagent get mybot", "subagent", "get", []string{"mybot"}},
{"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}}, {"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}},
+78 -12
View File
@@ -8,10 +8,18 @@ import (
// CommandContext carries execution context for a sub-command. // CommandContext carries execution context for a sub-command.
type CommandContext struct { type CommandContext struct {
Ctx context.Context Ctx context.Context
BotID string BotID string
Role string // "owner", "admin", "member", or "" (guest) Role string // "owner", "admin", "member", or "" (guest)
Args []string Args []string
ChannelIdentityID string
UserID string
ChannelType string
ConversationType string
ConversationID string
ThreadID string
RouteID string
SessionID string
} }
// SubCommand describes a single sub-command within a resource group. // SubCommand describes a single sub-command within a resource group.
@@ -50,15 +58,38 @@ func (g *CommandGroup) Usage() string {
fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description) fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description)
for _, name := range g.order { for _, name := range g.order {
sub := g.commands[name] sub := g.commands[name]
perm := "" desc := subSummary(sub)
if sub.IsWrite { if sub.IsWrite {
perm = " [owner]" desc += " [owner]"
} }
fmt.Fprintf(&b, "- %s%s\n", sub.Usage, perm) fmt.Fprintf(&b, "- %s\n", desc)
} }
fmt.Fprintf(&b, "\nUse /help %s <action> for details.", g.Name)
return b.String() return b.String()
} }
func (g *CommandGroup) ActionHelp(action string) string {
sub, ok := g.commands[action]
if !ok {
return fmt.Sprintf("Unknown action %q for /%s.\n\n%s", action, g.Name, g.Usage())
}
usage, summary := splitUsage(sub.Usage)
var b strings.Builder
fmt.Fprintf(&b, "/%s %s\n", g.Name, sub.Name)
if summary != "" {
fmt.Fprintf(&b, "- Summary: %s\n", summary)
}
if usage == "" {
usage = sub.Name
}
fmt.Fprintf(&b, "- Usage: /%s %s\n", g.Name, usage)
if sub.IsWrite {
b.WriteString("- Access: owner only\n")
}
fmt.Fprintf(&b, "- Tip: use /help %s to view sibling actions.", g.Name)
return strings.TrimRight(b.String(), "\n")
}
// Registry holds all registered command groups. // Registry holds all registered command groups.
type Registry struct { type Registry struct {
groups map[string]*CommandGroup groups map[string]*CommandGroup
@@ -83,12 +114,47 @@ func (r *Registry) GlobalHelp() string {
b.WriteString("/help - Show this help message\n") b.WriteString("/help - Show this help message\n")
b.WriteString("/new - Start a new conversation (resets session context)\n") b.WriteString("/new - Start a new conversation (resets session context)\n")
b.WriteString("/stop - Stop the current generation\n\n") b.WriteString("/stop - Stop the current generation\n\n")
for i, name := range r.order { for _, name := range r.order {
if i > 0 {
b.WriteByte('\n')
}
group := r.groups[name] group := r.groups[name]
b.WriteString(group.Usage()) fmt.Fprintf(&b, "- /%s - %s\n", group.Name, group.Description)
} }
b.WriteString("\nUse /help <group> to view actions, e.g. /help model")
return strings.TrimRight(b.String(), "\n") return strings.TrimRight(b.String(), "\n")
} }
func (r *Registry) GroupHelp(name string) string {
group, ok := r.groups[name]
if !ok {
return fmt.Sprintf("Unknown command group: /%s\n\n%s", name, r.GlobalHelp())
}
return group.Usage()
}
func (r *Registry) ActionHelp(groupName, action string) string {
group, ok := r.groups[groupName]
if !ok {
return fmt.Sprintf("Unknown command group: /%s\n\n%s", groupName, r.GlobalHelp())
}
return group.ActionHelp(action)
}
func splitUsage(usage string) (commandUsage string, summary string) {
usage = strings.TrimSpace(usage)
if usage == "" {
return "", ""
}
parts := strings.SplitN(usage, " - ", 2)
commandUsage = strings.TrimSpace(parts[0])
if len(parts) > 1 {
summary = strings.TrimSpace(parts[1])
}
return commandUsage, summary
}
func subSummary(sub SubCommand) string {
usage, summary := splitUsage(sub.Usage)
if summary == "" {
return usage
}
return fmt.Sprintf("%s - %s", sub.Name, summary)
}
+39 -6
View File
@@ -13,6 +13,9 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
Name: "list", Name: "list",
Usage: "list - List all search providers", Usage: "list - List all search providers",
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if h.searchProvService == nil {
return "Search provider service is not available.", nil
}
items, err := h.searchProvService.List(cc.Ctx, "") items, err := h.searchProvService.List(cc.Ctx, "")
if err != nil { if err != nil {
return "", err return "", err
@@ -20,14 +23,40 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
if len(items) == 0 { if len(items) == 0 {
return "No search providers found.", nil return "No search providers found.", nil
} }
records := make([][]kv, 0, len(items)) settingsResp, _ := h.getBotSettings(cc)
currentRecords := make([][]kv, 0, 1)
otherRecords := make([][]kv, 0, len(items))
for _, item := range items { for _, item := range items {
records = append(records, []kv{ label := item.Name
{"Name", item.Name}, record := []kv{
{"Name", label},
{"Provider", item.Provider}, {"Provider", item.Provider},
}) }
if item.ID == settingsResp.SearchProviderID {
label += " [current]"
record[0].value = label
currentRecords = append(currentRecords, record)
continue
}
otherRecords = append(otherRecords, record)
} }
return formatItems(records), nil currentRecords = append(currentRecords, otherRecords...)
records := currentRecords
return formatLimitedItems(records, defaultListLimit, "Use /search current to inspect the active provider."), nil
},
})
g.Register(SubCommand{
Name: "current",
Usage: "current - Show the current search provider",
Handler: func(cc CommandContext) (string, error) {
if h.settingsService == nil {
return "Settings service is not available.", nil
}
settingsResp, err := h.getBotSettings(cc)
if err != nil {
return "", err
}
return formatKV([]kv{{"Search Provider", h.resolveSearchProviderName(cc, settingsResp.SearchProviderID)}}), nil
}, },
}) })
g.Register(SubCommand{ g.Register(SubCommand{
@@ -38,7 +67,11 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
if len(cc.Args) < 1 { if len(cc.Args) < 1 {
return "Usage: /search set <name>", nil return "Usage: /search set <name>", nil
} }
if h.settingsService == nil {
return "Settings service is not available.", nil
}
name := cc.Args[0] name := cc.Args[0]
before, _ := h.getBotSettings(cc)
items, err := h.searchProvService.List(cc.Ctx, "") items, err := h.searchProvService.List(cc.Ctx, "")
if err != nil { if err != nil {
return "", err return "", err
@@ -51,7 +84,7 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("Search provider set to %q.", item.Name), nil return formatChangedValue("Search provider", h.resolveSearchProviderName(cc, before.SearchProviderID), item.Name), nil
} }
} }
return fmt.Sprintf("Search provider %q not found.", name), nil return fmt.Sprintf("Search provider %q not found.", name), nil
+8
View File
@@ -1,6 +1,7 @@
package command package command
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -106,6 +107,13 @@ func settingsUpdateUsage() string {
"- --heartbeat_model_id <id>" "- --heartbeat_model_id <id>"
} }
func (h *Handler) getBotSettings(cc CommandContext) (settings.Settings, error) {
if h.settingsService == nil {
return settings.Settings{}, errors.New("settings service is not available")
}
return h.settingsService.GetBot(cc.Ctx, cc.BotID)
}
// resolveModelName resolves a model UUID to "model_name (provider_name)". // resolveModelName resolves a model UUID to "model_name (provider_name)".
func (h *Handler) resolveModelName(cc CommandContext, modelID string) string { func (h *Handler) resolveModelName(cc CommandContext, modelID string) string {
if modelID == "" { if modelID == "" {
+94 -57
View File
@@ -6,7 +6,9 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
) )
func (h *Handler) buildStatusGroup() *CommandGroup { func (h *Handler) buildStatusGroup() *CommandGroup {
@@ -16,78 +18,113 @@ func (h *Handler) buildStatusGroup() *CommandGroup {
Name: "show", Name: "show",
Usage: "show - Show current session status", Usage: "show - Show current session status",
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if strings.TrimSpace(cc.SessionID) == "" {
return "No active session found for this conversation.", nil
}
return h.renderSessionStatus(cc, cc.SessionID, "current conversation")
},
})
g.Register(SubCommand{
Name: "latest",
Usage: "latest - Show the latest session status for this bot",
Handler: func(cc CommandContext) (string, error) {
if h.queries == nil {
return "Session info is not available.", nil
}
botUUID, err := parseBotUUID(cc.BotID) botUUID, err := parseBotUUID(cc.BotID)
if err != nil { if err != nil {
return "", err return "", err
} }
sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID) sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID)
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return "No active session found.", nil return "No session found for this bot.", nil
} }
return "", err return "", err
} }
return h.renderSessionStatus(cc, sessionID.String(), "latest bot session")
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, sessionID)
if err != nil {
return "", fmt.Errorf("count messages: %w", err)
}
var usedTokens int64
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, sessionID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return "", fmt.Errorf("get usage: %w", err)
}
if err == nil {
usedTokens = latestUsage
}
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, sessionID)
if err != nil {
return "", fmt.Errorf("get cache: %w", err)
}
var contextWindowStr string
if h.settingsService != nil {
s, sErr := h.settingsService.GetBot(cc.Ctx, cc.BotID)
if sErr == nil && s.ChatModelID != "" && h.modelsService != nil {
m, mErr := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
if mErr == nil && m.Config.ContextWindow != nil {
contextWindowStr = formatTokens(int64(*m.Config.ContextWindow))
}
}
}
var cacheHitRate float64
if cacheRow.TotalInputTokens > 0 {
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
}
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, sessionID)
var b strings.Builder
b.WriteString("Session Status:\n\n")
fmt.Fprintf(&b, "- Messages: %d\n", msgCount)
if contextWindowStr != "" {
fmt.Fprintf(&b, "- Context: %s / %s\n", formatTokens(usedTokens), contextWindowStr)
} else {
fmt.Fprintf(&b, "- Context: %s\n", formatTokens(usedTokens))
}
fmt.Fprintf(&b, "- Cache Hit Rate: %.1f%%\n", cacheHitRate)
fmt.Fprintf(&b, "- Cache Read: %s\n", formatTokens(cacheRow.CacheReadTokens))
fmt.Fprintf(&b, "- Cache Write: %s\n", formatTokens(cacheRow.CacheWriteTokens))
if len(skills) > 0 {
fmt.Fprintf(&b, "- Skills: %s\n", strings.Join(skills, ", "))
}
return strings.TrimRight(b.String(), "\n"), nil
}, },
}) })
return g return g
} }
func (h *Handler) renderSessionStatus(cc CommandContext, sessionID string, scope string) (string, error) {
if h.queries == nil {
return "Session info is not available.", nil
}
pgSessionID, err := parseCommandUUID(sessionID)
if err != nil {
return "", err
}
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, pgSessionID)
if err != nil {
return "", fmt.Errorf("count messages: %w", err)
}
var usedTokens int64
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, pgSessionID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return "", fmt.Errorf("get usage: %w", err)
}
if err == nil {
usedTokens = latestUsage
}
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, pgSessionID)
if err != nil {
return "", fmt.Errorf("get cache: %w", err)
}
var cacheHitRate float64
if cacheRow.TotalInputTokens > 0 {
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
}
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, pgSessionID)
contextUsage := formatTokens(usedTokens)
if contextWindow := h.resolveContextWindow(cc); contextWindow != "" {
contextUsage = contextUsage + " / " + contextWindow
}
pairs := []kv{
{"Scope", scope},
{"Session ID", sessionID},
{"Messages", strconv.FormatInt(msgCount, 10)},
{"Context", contextUsage},
{"Cache Hit Rate", fmt.Sprintf("%.1f%%", cacheHitRate)},
{"Cache Read", formatTokens(cacheRow.CacheReadTokens)},
{"Cache Write", formatTokens(cacheRow.CacheWriteTokens)},
}
if len(skills) > 0 {
pairs = append(pairs, kv{"Skills", strings.Join(skills, ", ")})
}
return formatKV(pairs), nil
}
func (h *Handler) resolveContextWindow(cc CommandContext) string {
if h.settingsService == nil || h.modelsService == nil {
return ""
}
s, err := h.settingsService.GetBot(cc.Ctx, cc.BotID)
if err != nil || s.ChatModelID == "" {
return ""
}
m, err := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
if err != nil || m.Config.ContextWindow == nil {
return ""
}
return formatTokens(int64(*m.Config.ContextWindow))
}
func parseCommandUUID(id string) (pgtype.UUID, error) {
parsed, err := uuid.Parse(strings.TrimSpace(id))
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid uuid: %w", err)
}
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
}
func formatTokens(n int64) string { func formatTokens(n int64) string {
if n >= 1_000_000 { if n >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(n)/1_000_000) return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
+6
View File
@@ -18,6 +18,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
Name: "summary", Name: "summary",
Usage: "summary - Token usage summary (last 7 days)", Usage: "summary - Token usage summary (last 7 days)",
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if h.queries == nil {
return "Usage info is not available.", nil
}
botUUID, err := parseBotUUID(cc.BotID) botUUID, err := parseBotUUID(cc.BotID)
if err != nil { if err != nil {
return "", err return "", err
@@ -89,6 +92,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
Name: "by-model", Name: "by-model",
Usage: "by-model - Token usage grouped by model", Usage: "by-model - Token usage grouped by model",
Handler: func(cc CommandContext) (string, error) { Handler: func(cc CommandContext) (string, error) {
if h.queries == nil {
return "Usage info is not available.", nil
}
botUUID, err := parseBotUUID(cc.BotID) botUUID, err := parseBotUUID(cc.BotID)
if err != nil { if err != nil {
return "", err return "", err