mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(command): improve slash command UX (#361)
Make slash commands easier to navigate in chat by splitting help into levels, compacting list output, and surfacing current selections for model, search, memory, and browser settings. Also route /status to the active conversation session and add an access inspector so users can understand their current command and ACL context.
This commit is contained in:
@@ -604,6 +604,7 @@ func provideChannelRouter(
|
||||
emailOutboxService,
|
||||
heartbeatService,
|
||||
queries,
|
||||
aclService,
|
||||
&commandSkillLoaderAdapter{handler: containerdHandler},
|
||||
&commandContainerFSAdapter{manager: manager},
|
||||
))
|
||||
@@ -779,6 +780,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
|
||||
if err != nil {
|
||||
return inbound.SessionResult{}, err
|
||||
}
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
||||
if err != nil {
|
||||
|
||||
@@ -492,6 +492,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc
|
||||
emailOutboxService,
|
||||
heartbeatService,
|
||||
queries,
|
||||
aclService,
|
||||
&commandSkillLoaderAdapter{handler: containerdHandler},
|
||||
&commandContainerFSAdapter{manager: manager},
|
||||
))
|
||||
@@ -877,6 +878,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
|
||||
if err != nil {
|
||||
return inbound.SessionResult{}, err
|
||||
}
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
||||
if err != nil {
|
||||
|
||||
@@ -71,6 +71,7 @@ type ttsModelResolver interface {
|
||||
// SessionEnsurer resolves or creates an active session for a route.
|
||||
type SessionEnsurer interface {
|
||||
EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error)
|
||||
GetActiveSession(ctx context.Context, routeID string) (SessionResult, error)
|
||||
// CreateNewSession always creates a fresh session and sets it as the
|
||||
// active session for the given route, replacing any previous one.
|
||||
// sessionType defaults to "chat" if empty.
|
||||
@@ -298,11 +299,23 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
if isStopCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
return p.handleStopCommand(ctx, cfg, msg, sender, identity)
|
||||
}
|
||||
if isStatusCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
return p.handleStatusCommand(ctx, cfg, msg, sender, identity)
|
||||
}
|
||||
|
||||
// Skip generic command handler for mode-prefix commands (/btw, /now, /next)
|
||||
// so they pass through to mode detection below.
|
||||
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText)
|
||||
reply, err := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
|
||||
BotID: strings.TrimSpace(identity.BotID),
|
||||
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
|
||||
UserID: strings.TrimSpace(identity.UserID),
|
||||
Text: cmdText,
|
||||
ChannelType: msg.Channel.String(),
|
||||
ConversationType: strings.TrimSpace(msg.Conversation.Type),
|
||||
ConversationID: strings.TrimSpace(msg.Conversation.ID),
|
||||
ThreadID: extractThreadID(msg),
|
||||
})
|
||||
if err != nil {
|
||||
reply = "Error: " + err.Error()
|
||||
}
|
||||
@@ -2489,6 +2502,18 @@ func isNewSessionCommand(cmdText string) bool {
|
||||
return parsed.Resource == "new"
|
||||
}
|
||||
|
||||
func isStatusCommand(cmdText string) bool {
|
||||
extracted := command.ExtractCommandText(cmdText)
|
||||
if extracted == "" {
|
||||
return false
|
||||
}
|
||||
parsed, err := command.Parse(extracted)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return parsed.Resource == "status"
|
||||
}
|
||||
|
||||
// resolveNewSessionType determines the session type for /new command.
|
||||
// /new chat → chat, /new discuss → discuss, /new (no arg) → default by context.
|
||||
// WebUI (local channel) always defaults to chat.
|
||||
@@ -2611,3 +2636,83 @@ func (p *ChannelInboundProcessor) handleNewSessionCommand(
|
||||
Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)},
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ChannelInboundProcessor) handleStatusCommand(
|
||||
ctx context.Context,
|
||||
cfg channel.ChannelConfig,
|
||||
msg channel.InboundMessage,
|
||||
sender channel.StreamReplySender,
|
||||
identity InboundIdentity,
|
||||
) error {
|
||||
target := strings.TrimSpace(msg.ReplyTarget)
|
||||
if target == "" {
|
||||
return errors.New("reply target missing for /status command")
|
||||
}
|
||||
if p.routeResolver == nil {
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: "Error: route resolver not configured."},
|
||||
})
|
||||
}
|
||||
if p.commandHandler == nil {
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: "Error: command handler not configured."},
|
||||
})
|
||||
}
|
||||
|
||||
threadID := extractThreadID(msg)
|
||||
routeMetadata := buildRouteMetadata(msg, identity)
|
||||
p.enrichConversationAvatar(ctx, cfg, msg, routeMetadata)
|
||||
resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{
|
||||
BotID: identity.BotID,
|
||||
Platform: msg.Channel.String(),
|
||||
ConversationID: msg.Conversation.ID,
|
||||
ThreadID: threadID,
|
||||
ConversationType: msg.Conversation.Type,
|
||||
ChannelIdentityID: identity.UserID,
|
||||
ChannelConfigID: identity.ChannelConfigID,
|
||||
ReplyTarget: target,
|
||||
Metadata: routeMetadata,
|
||||
})
|
||||
if err != nil {
|
||||
if p.logger != nil {
|
||||
p.logger.Warn("resolve route for /status command failed", slog.Any("error", err))
|
||||
}
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: "Error: failed to resolve conversation route."},
|
||||
})
|
||||
}
|
||||
|
||||
sessionID := ""
|
||||
if p.sessionEnsurer != nil {
|
||||
sess, sessErr := p.sessionEnsurer.GetActiveSession(ctx, resolved.RouteID)
|
||||
if sessErr == nil {
|
||||
sessionID = strings.TrimSpace(sess.ID)
|
||||
} else if p.logger != nil {
|
||||
p.logger.Debug("resolve active session for /status command failed", slog.Any("error", sessErr))
|
||||
}
|
||||
}
|
||||
|
||||
reply, execErr := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
|
||||
BotID: strings.TrimSpace(identity.BotID),
|
||||
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
|
||||
UserID: strings.TrimSpace(identity.UserID),
|
||||
Text: rawTextForCommand(msg, strings.TrimSpace(msg.Message.PlainText())),
|
||||
ChannelType: msg.Channel.String(),
|
||||
ConversationType: strings.TrimSpace(msg.Conversation.Type),
|
||||
ConversationID: strings.TrimSpace(msg.Conversation.ID),
|
||||
ThreadID: threadID,
|
||||
RouteID: strings.TrimSpace(resolved.RouteID),
|
||||
SessionID: sessionID,
|
||||
})
|
||||
if execErr != nil {
|
||||
reply = "Error: " + execErr.Error()
|
||||
}
|
||||
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: reply},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,11 +13,15 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/acl"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/channel/identities"
|
||||
"github.com/memohai/memoh/internal/channel/route"
|
||||
"github.com/memohai/memoh/internal/command"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/media"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
@@ -185,6 +189,71 @@ type fakeChatACL struct {
|
||||
lastReq acl.EvaluateRequest
|
||||
}
|
||||
|
||||
type fakeSessionEnsurer struct {
|
||||
activeSession SessionResult
|
||||
activeErr error
|
||||
lastRouteID string
|
||||
}
|
||||
|
||||
func (f *fakeSessionEnsurer) EnsureActiveSession(_ context.Context, _, routeID, _ string) (SessionResult, error) {
|
||||
f.lastRouteID = routeID
|
||||
if f.activeErr != nil {
|
||||
return SessionResult{}, f.activeErr
|
||||
}
|
||||
return f.activeSession, nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionEnsurer) GetActiveSession(_ context.Context, routeID string) (SessionResult, error) {
|
||||
f.lastRouteID = routeID
|
||||
if f.activeErr != nil {
|
||||
return SessionResult{}, f.activeErr
|
||||
}
|
||||
return f.activeSession, nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionEnsurer) CreateNewSession(_ context.Context, _, routeID, _, _ string) (SessionResult, error) {
|
||||
f.lastRouteID = routeID
|
||||
if f.activeErr != nil {
|
||||
return SessionResult{}, f.activeErr
|
||||
}
|
||||
return f.activeSession, nil
|
||||
}
|
||||
|
||||
type fakeCommandQueries struct {
|
||||
messageCount int64
|
||||
usage int64
|
||||
cacheRow dbsqlc.GetSessionCacheStatsRow
|
||||
skills []string
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
|
||||
return pgtype.UUID{}, errors.New("unexpected latest session lookup")
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
return f.messageCount, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
return f.usage, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
|
||||
return f.cacheRow, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
|
||||
return f.skills, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) {
|
||||
f.calls++
|
||||
f.lastReq = req
|
||||
@@ -548,6 +617,72 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelInboundProcessorStatusUsesRouteSession(t *testing.T) {
|
||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-status"}}
|
||||
policySvc := &fakePolicyService{}
|
||||
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-status", RouteID: "route-status"}}
|
||||
gateway := &fakeChatGateway{}
|
||||
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
|
||||
processor.SetSessionEnsurer(&fakeSessionEnsurer{
|
||||
activeSession: SessionResult{ID: "11111111-1111-1111-1111-111111111111", Type: "chat"},
|
||||
})
|
||||
processor.SetCommandHandler(command.NewHandler(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&fakeCommandQueries{
|
||||
messageCount: 9,
|
||||
usage: 512,
|
||||
cacheRow: dbsqlc.GetSessionCacheStatsRow{
|
||||
CacheReadTokens: 64,
|
||||
CacheWriteTokens: 32,
|
||||
TotalInputTokens: 512,
|
||||
},
|
||||
skills: []string{"search"},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
))
|
||||
sender := &fakeReplySender{}
|
||||
|
||||
cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("discord")}
|
||||
msg := channel.InboundMessage{
|
||||
BotID: "bot-1",
|
||||
Channel: channel.ChannelType("discord"),
|
||||
Message: channel.Message{Text: "/status"},
|
||||
ReplyTarget: "discord:status",
|
||||
Sender: channel.Identity{SubjectID: "user-1"},
|
||||
Conversation: channel.Conversation{
|
||||
ID: "conv-status",
|
||||
Type: channel.ConversationTypePrivate,
|
||||
},
|
||||
}
|
||||
|
||||
if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one status reply, got %d", len(sender.sent))
|
||||
}
|
||||
if !strings.Contains(sender.sent[0].Message.Text, "- Scope: current conversation") {
|
||||
t.Fatalf("expected current conversation scope, got %q", sender.sent[0].Message.Text)
|
||||
}
|
||||
if !strings.Contains(sender.sent[0].Message.Text, "- Session ID: 11111111-1111-1111-1111-111111111111") {
|
||||
t.Fatalf("expected active route session in reply, got %q", sender.sent[0].Message.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/acl"
|
||||
)
|
||||
|
||||
func (h *Handler) buildAccessGroup() *CommandGroup {
|
||||
g := newCommandGroup("access", "Inspect identity and permission context")
|
||||
g.DefaultAction = "show"
|
||||
g.Register(SubCommand{
|
||||
Name: "show",
|
||||
Usage: "show - Show current identity, write access, and chat ACL context",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
writeAccess := "no"
|
||||
if cc.Role == "owner" {
|
||||
writeAccess = "yes"
|
||||
}
|
||||
|
||||
pairs := []kv{
|
||||
{"Channel Identity", fallbackValue(cc.ChannelIdentityID)},
|
||||
{"Linked User", fallbackValue(cc.UserID)},
|
||||
{"Bot Role", fallbackValue(cc.Role)},
|
||||
{"Write Commands", writeAccess},
|
||||
{"Channel", fallbackValue(cc.ChannelType)},
|
||||
{"Conversation Type", fallbackValue(cc.ConversationType)},
|
||||
{"Conversation ID", fallbackValue(cc.ConversationID)},
|
||||
{"Thread ID", fallbackValue(cc.ThreadID)},
|
||||
}
|
||||
if strings.TrimSpace(cc.RouteID) != "" {
|
||||
pairs = append(pairs, kv{"Route ID", cc.RouteID})
|
||||
}
|
||||
if strings.TrimSpace(cc.SessionID) != "" {
|
||||
pairs = append(pairs, kv{"Session ID", cc.SessionID})
|
||||
}
|
||||
|
||||
aclStatus := "unavailable"
|
||||
if h.aclEvaluator != nil && strings.TrimSpace(cc.ChannelType) != "" {
|
||||
allowed, err := h.aclEvaluator.Evaluate(cc.Ctx, acl.EvaluateRequest{
|
||||
BotID: cc.BotID,
|
||||
ChannelIdentityID: cc.ChannelIdentityID,
|
||||
ChannelType: cc.ChannelType,
|
||||
SourceScope: acl.SourceScope{
|
||||
ConversationType: cc.ConversationType,
|
||||
ConversationID: cc.ConversationID,
|
||||
ThreadID: cc.ThreadID,
|
||||
},
|
||||
})
|
||||
switch {
|
||||
case err != nil:
|
||||
aclStatus = "error: " + err.Error()
|
||||
case allowed:
|
||||
aclStatus = "allow"
|
||||
default:
|
||||
aclStatus = "deny"
|
||||
}
|
||||
}
|
||||
pairs = append(pairs, kv{"Chat ACL", aclStatus})
|
||||
|
||||
return formatKV(pairs), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
}
|
||||
|
||||
func fallbackValue(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "(none)"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func formatChangedValue(label, before, after string) string {
|
||||
return fmt.Sprintf("%s: %s -> %s", label, fallbackValue(before), fallbackValue(after))
|
||||
}
|
||||
@@ -13,6 +13,9 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
Name: "list",
|
||||
Usage: "list - List all browser contexts",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.browserCtxService == nil {
|
||||
return "Browser context service is not available.", nil
|
||||
}
|
||||
items, err := h.browserCtxService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -20,13 +23,37 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
if len(items) == 0 {
|
||||
return "No browser contexts found.", nil
|
||||
}
|
||||
records := make([][]kv, 0, len(items))
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
currentRecords := make([][]kv, 0, 1)
|
||||
otherRecords := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
records = append(records, []kv{
|
||||
{"Name", item.Name},
|
||||
})
|
||||
label := item.Name
|
||||
record := []kv{{"Name", label}}
|
||||
if item.ID == settingsResp.BrowserContextID {
|
||||
label += " [current]"
|
||||
record[0].value = label
|
||||
currentRecords = append(currentRecords, record)
|
||||
continue
|
||||
}
|
||||
otherRecords = append(otherRecords, record)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
currentRecords = append(currentRecords, otherRecords...)
|
||||
records := currentRecords
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /browser current to inspect the active context."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show the current browser context",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{{"Browser Context", h.resolveBrowserContextName(cc, settingsResp.BrowserContextID)}}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -37,7 +64,11 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /browser set <name>", nil
|
||||
}
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
name := cc.Args[0]
|
||||
before, _ := h.getBotSettings(cc)
|
||||
items, err := h.browserCtxService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -50,7 +81,7 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Browser context set to %q.", item.Name), nil
|
||||
return formatChangedValue("Browser context", h.resolveBrowserContextName(cc, before.BrowserContextID), item.Name), nil
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Browser context %q not found.", name), nil
|
||||
|
||||
@@ -16,5 +16,6 @@ func (h *Handler) buildRegistry() *Registry {
|
||||
r.RegisterGroup(h.buildSkillGroup())
|
||||
r.RegisterGroup(h.buildFSGroup())
|
||||
r.RegisterGroup(h.buildStatusGroup())
|
||||
r.RegisterGroup(h.buildAccessGroup())
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
||||
{"Provider", item.Provider},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /email bindings to inspect bot bindings."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -46,7 +46,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
||||
{"Permissions", perms},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /email outbox to inspect recent sends."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -70,7 +70,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
||||
{"Sent", item.SentAt.Format("01-02 15:04")},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, 10, "Use the Web UI for older outbox entries."), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
|
||||
@@ -6,19 +6,10 @@ import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const defaultListLimit = 12
|
||||
|
||||
// formatItems renders a list of records as a Markdown-style list.
|
||||
// Each record is a slice of kv pairs; the first pair's value is used as the
|
||||
// bullet title, and subsequent pairs are indented beneath it.
|
||||
//
|
||||
// Example output:
|
||||
//
|
||||
// - mybot
|
||||
// Description: A helpful assistant
|
||||
// ID: abc123
|
||||
//
|
||||
// - another
|
||||
// Description: Something else
|
||||
// ID: def456
|
||||
// Each record is rendered on a single line so long lists stay readable in IM.
|
||||
func formatItems(items [][]kv) string {
|
||||
if len(items) == 0 {
|
||||
return ""
|
||||
@@ -31,14 +22,41 @@ func formatItems(items [][]kv) string {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
fmt.Fprintf(&b, "- %s\n", record[0].value)
|
||||
fmt.Fprintf(&b, "- %s", record[0].value)
|
||||
extras := make([]string, 0, len(record)-1)
|
||||
for _, pair := range record[1:] {
|
||||
fmt.Fprintf(&b, " %s: %s\n", pair.key, pair.value)
|
||||
if strings.TrimSpace(pair.value) == "" {
|
||||
continue
|
||||
}
|
||||
extras = append(extras, fmt.Sprintf("%s: %s", pair.key, pair.value))
|
||||
}
|
||||
if len(extras) > 0 {
|
||||
fmt.Fprintf(&b, " | %s", strings.Join(extras, " | "))
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatLimitedItems(items [][]kv, limit int, hint string) string {
|
||||
if len(items) == 0 {
|
||||
return ""
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = defaultListLimit
|
||||
}
|
||||
total := len(items)
|
||||
if total <= limit {
|
||||
return formatItems(items)
|
||||
}
|
||||
shown := items[:limit]
|
||||
result := formatItems(shown)
|
||||
suffix := fmt.Sprintf("Showing %d of %d items.", len(shown), total)
|
||||
if strings.TrimSpace(hint) != "" {
|
||||
suffix += " " + strings.TrimSpace(hint)
|
||||
}
|
||||
return result + "\n\n" + suffix
|
||||
}
|
||||
|
||||
// formatKV renders key-value pairs as a simple Markdown list.
|
||||
//
|
||||
// Example output:
|
||||
|
||||
+57
-13
@@ -4,10 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/browsercontexts"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
emailpkg "github.com/memohai/memoh/internal/email"
|
||||
"github.com/memohai/memoh/internal/heartbeat"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
@@ -56,13 +56,28 @@ type Handler struct {
|
||||
emailService *emailpkg.Service
|
||||
emailOutboxService *emailpkg.OutboxService
|
||||
heartbeatService *heartbeat.Service
|
||||
queries *dbsqlc.Queries
|
||||
queries CommandQueries
|
||||
aclEvaluator AccessEvaluator
|
||||
skillLoader SkillLoader
|
||||
containerFS ContainerFS
|
||||
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// ExecuteInput carries the caller identity and channel context for command execution.
|
||||
type ExecuteInput struct {
|
||||
BotID string
|
||||
ChannelIdentityID string
|
||||
UserID string
|
||||
Text string
|
||||
ChannelType string
|
||||
ConversationType string
|
||||
ConversationID string
|
||||
ThreadID string
|
||||
RouteID string
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with all required services.
|
||||
func NewHandler(
|
||||
log *slog.Logger,
|
||||
@@ -78,7 +93,8 @@ func NewHandler(
|
||||
emailService *emailpkg.Service,
|
||||
emailOutboxService *emailpkg.OutboxService,
|
||||
heartbeatService *heartbeat.Service,
|
||||
queries *dbsqlc.Queries,
|
||||
queries CommandQueries,
|
||||
aclEvaluator AccessEvaluator,
|
||||
skillLoader SkillLoader,
|
||||
containerFS ContainerFS,
|
||||
) *Handler {
|
||||
@@ -99,6 +115,7 @@ func NewHandler(
|
||||
emailOutboxService: emailOutboxService,
|
||||
heartbeatService: heartbeatService,
|
||||
queries: queries,
|
||||
aclEvaluator: aclEvaluator,
|
||||
skillLoader: skillLoader,
|
||||
containerFS: containerFS,
|
||||
logger: log.With(slog.String("component", "command")),
|
||||
@@ -140,7 +157,16 @@ func (h *Handler) IsCommand(text string) bool {
|
||||
|
||||
// Execute parses and runs a slash command, returning the text reply.
|
||||
func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) {
|
||||
cmdText := ExtractCommandText(text)
|
||||
return h.ExecuteWithInput(ctx, ExecuteInput{
|
||||
BotID: botID,
|
||||
ChannelIdentityID: channelIdentityID,
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
// ExecuteWithInput parses and runs a slash command with channel/session context.
|
||||
func (h *Handler) ExecuteWithInput(ctx context.Context, input ExecuteInput) (string, error) {
|
||||
cmdText := ExtractCommandText(input.Text)
|
||||
if cmdText == "" {
|
||||
return h.registry.GlobalHelp(), nil
|
||||
}
|
||||
@@ -151,12 +177,16 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
|
||||
|
||||
// Resolve the user's role in this bot.
|
||||
role := ""
|
||||
if h.roleResolver != nil && channelIdentityID != "" {
|
||||
r, err := h.roleResolver.GetMemberRole(ctx, botID, channelIdentityID)
|
||||
roleIdentityID := input.ChannelIdentityID
|
||||
if strings.TrimSpace(input.UserID) != "" {
|
||||
roleIdentityID = strings.TrimSpace(input.UserID)
|
||||
}
|
||||
if h.roleResolver != nil && roleIdentityID != "" {
|
||||
r, err := h.roleResolver.GetMemberRole(ctx, input.BotID, roleIdentityID)
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to resolve member role",
|
||||
slog.String("bot_id", botID),
|
||||
slog.String("channel_identity_id", channelIdentityID),
|
||||
slog.String("bot_id", input.BotID),
|
||||
slog.String("role_identity_id", roleIdentityID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
} else {
|
||||
@@ -165,15 +195,29 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
|
||||
}
|
||||
|
||||
cc := CommandContext{
|
||||
Ctx: ctx,
|
||||
BotID: botID,
|
||||
Role: role,
|
||||
Args: parsed.Args,
|
||||
Ctx: ctx,
|
||||
BotID: input.BotID,
|
||||
Role: role,
|
||||
Args: parsed.Args,
|
||||
ChannelIdentityID: strings.TrimSpace(input.ChannelIdentityID),
|
||||
UserID: strings.TrimSpace(input.UserID),
|
||||
ChannelType: strings.TrimSpace(input.ChannelType),
|
||||
ConversationType: strings.TrimSpace(input.ConversationType),
|
||||
ConversationID: strings.TrimSpace(input.ConversationID),
|
||||
ThreadID: strings.TrimSpace(input.ThreadID),
|
||||
RouteID: strings.TrimSpace(input.RouteID),
|
||||
SessionID: strings.TrimSpace(input.SessionID),
|
||||
}
|
||||
|
||||
// /help
|
||||
if parsed.Resource == "help" {
|
||||
return h.registry.GlobalHelp(), nil
|
||||
if parsed.Action == "" {
|
||||
return h.registry.GlobalHelp(), nil
|
||||
}
|
||||
if len(parsed.Args) == 0 {
|
||||
return h.registry.GroupHelp(parsed.Action), nil
|
||||
}
|
||||
return h.registry.ActionHelp(parsed.Action, parsed.Args[0]), nil
|
||||
}
|
||||
|
||||
// Top-level commands (e.g. /new) are handled by the channel inbound
|
||||
|
||||
@@ -5,6 +5,10 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
@@ -25,9 +29,58 @@ type fakeScheduleService struct {
|
||||
items []schedule.Schedule
|
||||
}
|
||||
|
||||
type fakeCommandQueries struct {
|
||||
latestSessionID pgtype.UUID
|
||||
latestSessionErr error
|
||||
messageCount int64
|
||||
latestUsage int64
|
||||
latestUsageErr error
|
||||
cacheRow dbsqlc.GetSessionCacheStatsRow
|
||||
cacheErr error
|
||||
skills []string
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
|
||||
return f.latestSessionID, f.latestSessionErr
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
return f.messageCount, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
if f.latestUsageErr != nil {
|
||||
return 0, f.latestUsageErr
|
||||
}
|
||||
return f.latestUsage, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
|
||||
if f.cacheErr != nil {
|
||||
return dbsqlc.GetSessionCacheStatsRow{}, f.cacheErr
|
||||
}
|
||||
return f.cacheRow, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
|
||||
return f.skills, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// newTestHandler creates a Handler with nil services for use in tests.
|
||||
func newTestHandler(roleResolver MemberRoleResolver) *Handler {
|
||||
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func newTestHandlerWithQueries(roleResolver MemberRoleResolver, queries CommandQueries) *Handler {
|
||||
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, queries, nil, nil, nil)
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
@@ -73,6 +126,42 @@ func TestExecute_Help(t *testing.T) {
|
||||
if !strings.Contains(result, "Available commands") {
|
||||
t.Errorf("expected help text, got: %s", result)
|
||||
}
|
||||
if strings.Contains(result, "set-heartbeat") {
|
||||
t.Errorf("top-level help should not expand nested actions, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- /model - Manage bot models") {
|
||||
t.Errorf("expected top-level model entry, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_HelpGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "/model - Manage bot models") {
|
||||
t.Errorf("expected group help, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- set - Set the chat model [owner]") {
|
||||
t.Errorf("expected compact action summary, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_HelpAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model set")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "Usage: /model set <model_id> | <provider_name> <model_name>") {
|
||||
t.Errorf("expected action usage, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "Access: owner only") {
|
||||
t.Errorf("expected owner hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_UnknownCommand(t *testing.T) {
|
||||
@@ -207,8 +296,8 @@ func TestFormatItems(t *testing.T) {
|
||||
if !strings.Contains(result, "- foo") {
|
||||
t.Errorf("expected '- foo' bullet, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, " Type: bar") {
|
||||
t.Errorf("expected indented 'Type: bar', got: %s", result)
|
||||
if !strings.Contains(result, "- foo | Type: bar") {
|
||||
t.Errorf("expected compact line entry, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- longname") {
|
||||
t.Errorf("expected '- longname' bullet, got: %s", result)
|
||||
@@ -255,7 +344,7 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
|
||||
for _, group := range []string{
|
||||
"schedule", "mcp", "settings",
|
||||
"model", "memory", "search", "browser", "usage",
|
||||
"email", "heartbeat", "skill", "fs",
|
||||
"email", "heartbeat", "skill", "fs", "access",
|
||||
} {
|
||||
if !strings.Contains(help, "/"+group) {
|
||||
t.Errorf("missing /%s in global help", group)
|
||||
@@ -263,6 +352,86 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithInput_Access(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||
result, err := h.ExecuteWithInput(context.Background(), ExecuteInput{
|
||||
BotID: "bot-1",
|
||||
ChannelIdentityID: "channel-id-1",
|
||||
UserID: "user-id-1",
|
||||
Text: "/access",
|
||||
ChannelType: "discord",
|
||||
ConversationType: "thread",
|
||||
ConversationID: "conv-1",
|
||||
ThreadID: "thread-1",
|
||||
RouteID: "route-1",
|
||||
SessionID: "session-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "- Channel Identity: channel-id-1") {
|
||||
t.Errorf("expected channel identity in access output, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- Write Commands: yes") {
|
||||
t.Errorf("expected write access in access output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_StatusLatest(t *testing.T) {
|
||||
t.Parallel()
|
||||
sessionUUID := pgtype.UUID{}
|
||||
copy(sessionUUID.Bytes[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
sessionUUID.Valid = true
|
||||
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
|
||||
latestSessionID: sessionUUID,
|
||||
messageCount: 42,
|
||||
latestUsage: 1200,
|
||||
cacheRow: dbsqlc.GetSessionCacheStatsRow{
|
||||
CacheReadTokens: 300,
|
||||
CacheWriteTokens: 150,
|
||||
TotalInputTokens: 1200,
|
||||
},
|
||||
skills: []string{"search", "browser"},
|
||||
})
|
||||
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "- Scope: latest bot session") {
|
||||
t.Errorf("expected latest scope, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- Messages: 42") {
|
||||
t.Errorf("expected message count, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_StatusLatestNoRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
|
||||
latestSessionErr: pgx.ErrNoRows,
|
||||
})
|
||||
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "No session found for this bot.") {
|
||||
t.Errorf("expected no session message, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_StatusShowWithoutSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{})
|
||||
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/status")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "No active session found for this conversation.") {
|
||||
t.Errorf("expected route-aware no session message, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify write commands are tagged with [owner] in usage.
|
||||
func TestUsage_OwnerTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h *Handler) buildHeartbeatGroup() *CommandGroup {
|
||||
}
|
||||
records = append(records, rec)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, 10, "Use the Web UI for older heartbeat logs."), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package command
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/acl"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
// Skill represents a single skill loaded from a bot's container.
|
||||
type Skill struct {
|
||||
@@ -25,3 +32,20 @@ type ContainerFS interface {
|
||||
ListDir(ctx context.Context, botID, path string) ([]FSEntry, error)
|
||||
ReadFile(ctx context.Context, botID, path string) (string, error)
|
||||
}
|
||||
|
||||
// CommandQueries captures the sqlc methods used by slash commands.
|
||||
// *dbsqlc.Queries satisfies this interface directly.
|
||||
type CommandQueries interface {
|
||||
GetLatestSessionIDByBot(ctx context.Context, botID pgtype.UUID) (pgtype.UUID, error)
|
||||
CountMessagesBySession(ctx context.Context, sessionID pgtype.UUID) (int64, error)
|
||||
GetLatestAssistantUsage(ctx context.Context, sessionID pgtype.UUID) (int64, error)
|
||||
GetSessionCacheStats(ctx context.Context, sessionID pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error)
|
||||
GetSessionUsedSkills(ctx context.Context, sessionID pgtype.UUID) ([]string, error)
|
||||
GetTokenUsageByDayAndType(ctx context.Context, arg dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error)
|
||||
GetTokenUsageByModel(ctx context.Context, arg dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error)
|
||||
}
|
||||
|
||||
// AccessEvaluator checks whether the current channel context may trigger chat.
|
||||
type AccessEvaluator interface {
|
||||
Evaluate(ctx context.Context, req acl.EvaluateRequest) (bool, error)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func (h *Handler) buildMCPGroup() *CommandGroup {
|
||||
{"Status", item.Status},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /mcp get <name> for full details."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
|
||||
@@ -13,6 +13,9 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
Name: "list",
|
||||
Usage: "list - List all memory providers",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.memProvService == nil {
|
||||
return "Memory provider service is not available.", nil
|
||||
}
|
||||
items, err := h.memProvService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -20,18 +23,44 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
if len(items) == 0 {
|
||||
return "No memory providers found.", nil
|
||||
}
|
||||
records := make([][]kv, 0, len(items))
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
currentRecords := make([][]kv, 0, 1)
|
||||
otherRecords := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
def := ""
|
||||
if item.IsDefault {
|
||||
def = " (default)"
|
||||
}
|
||||
records = append(records, []kv{
|
||||
{"Name", item.Name + def},
|
||||
label := item.Name + def
|
||||
record := []kv{
|
||||
{"Name", label},
|
||||
{"Provider", item.Provider},
|
||||
})
|
||||
}
|
||||
if item.ID == settingsResp.MemoryProviderID {
|
||||
label += " [current]"
|
||||
record[0].value = label
|
||||
currentRecords = append(currentRecords, record)
|
||||
continue
|
||||
}
|
||||
otherRecords = append(otherRecords, record)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
currentRecords = append(currentRecords, otherRecords...)
|
||||
records := currentRecords
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /memory current to inspect the active provider."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show the current memory provider",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{{"Memory Provider", h.resolveMemoryProviderName(cc, settingsResp.MemoryProviderID)}}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -42,7 +71,11 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /memory set <name>", nil
|
||||
}
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
name := cc.Args[0]
|
||||
before, _ := h.getBotSettings(cc)
|
||||
items, err := h.memProvService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -55,7 +88,7 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Memory provider set to %q.", item.Name), nil
|
||||
return formatChangedValue("Memory provider", h.resolveMemoryProviderName(cc, before.MemoryProviderID), item.Name), nil
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Memory provider %q not found.", name), nil
|
||||
|
||||
+147
-13
@@ -1,7 +1,9 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
@@ -12,36 +14,81 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
||||
g := newCommandGroup("model", "Manage bot models")
|
||||
g.Register(SubCommand{
|
||||
Name: "list",
|
||||
Usage: "list - List all available chat models",
|
||||
Usage: "list [provider_name] - List available chat models",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.modelsService == nil {
|
||||
return "Model service is not available.", nil
|
||||
}
|
||||
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
filterProvider := ""
|
||||
if len(cc.Args) > 0 {
|
||||
filterProvider = strings.TrimSpace(strings.Join(cc.Args, " "))
|
||||
}
|
||||
items = h.filterModelsByProvider(cc, items, filterProvider)
|
||||
if len(items) == 0 {
|
||||
if filterProvider != "" {
|
||||
return fmt.Sprintf("No chat models found for provider %q.", filterProvider), nil
|
||||
}
|
||||
return "No chat models found.", nil
|
||||
}
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
sort.SliceStable(items, func(i, j int) bool {
|
||||
return modelSortRank(items[i], settingsResp) < modelSortRank(items[j], settingsResp)
|
||||
})
|
||||
records := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
provName := h.resolveProviderName(cc, item.ProviderID)
|
||||
label := item.Name
|
||||
markers := modelMarkers(item.ID, settingsResp)
|
||||
if len(markers) > 0 {
|
||||
label += " [" + strings.Join(markers, ", ") + "]"
|
||||
}
|
||||
records = append(records, []kv{
|
||||
{"Model", item.Name},
|
||||
{"Model", label},
|
||||
{"Provider", provName},
|
||||
{"Model ID", item.ModelID},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
hint := "Use /model current to inspect active selections."
|
||||
if filterProvider == "" {
|
||||
hint = "Use /model list <provider_name> to narrow results."
|
||||
}
|
||||
return formatLimitedItems(records, defaultListLimit, hint), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show current chat and heartbeat models",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{
|
||||
{"Chat Model", h.resolveModelName(cc, settingsResp.ChatModelID)},
|
||||
{"Heartbeat Model", h.resolveModelName(cc, settingsResp.HeartbeatModelID)},
|
||||
}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "set",
|
||||
Usage: "set <provider_name> <model_name> - Set the chat model",
|
||||
Usage: "set <model_id> | <provider_name> <model_name> - Set the chat model",
|
||||
IsWrite: true,
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if len(cc.Args) < 2 {
|
||||
return "Usage: /model set <provider_name> <model_name>", nil
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /model set <model_id> | <provider_name> <model_name>", nil
|
||||
}
|
||||
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
before, _ := h.getBotSettings(cc)
|
||||
modelResp, err := h.findModelForSelection(cc, cc.Args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -51,18 +98,22 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Chat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
|
||||
return formatChangedValue("Chat model", h.resolveModelName(cc, before.ChatModelID), h.resolveModelName(cc, modelResp.ID)), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "set-heartbeat",
|
||||
Usage: "set-heartbeat <provider_name> <model_name> - Set the heartbeat model",
|
||||
Usage: "set-heartbeat <model_id> | <provider_name> <model_name> - Set the heartbeat model",
|
||||
IsWrite: true,
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if len(cc.Args) < 2 {
|
||||
return "Usage: /model set-heartbeat <provider_name> <model_name>", nil
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /model set-heartbeat <model_id> | <provider_name> <model_name>", nil
|
||||
}
|
||||
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
before, _ := h.getBotSettings(cc)
|
||||
modelResp, err := h.findModelForSelection(cc, cc.Args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -72,7 +123,7 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Heartbeat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
|
||||
return formatChangedValue("Heartbeat model", h.resolveModelName(cc, before.HeartbeatModelID), h.resolveModelName(cc, modelResp.ID)), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
@@ -105,3 +156,86 @@ func (h *Handler) findModelByProviderAndName(cc CommandContext, providerName, mo
|
||||
}
|
||||
return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName)
|
||||
}
|
||||
|
||||
func (h *Handler) findModelForSelection(cc CommandContext, args []string) (models.GetResponse, error) {
|
||||
if h.modelsService == nil {
|
||||
return models.GetResponse{}, errors.New("model service is not available")
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return models.GetResponse{}, errors.New("model identifier is required")
|
||||
}
|
||||
if len(args) == 1 {
|
||||
return h.findModelByIDOrName(cc, args[0])
|
||||
}
|
||||
return h.findModelByProviderAndName(cc, args[0], strings.Join(args[1:], " "))
|
||||
}
|
||||
|
||||
func (h *Handler) findModelByIDOrName(cc CommandContext, target string) (models.GetResponse, error) {
|
||||
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, err
|
||||
}
|
||||
target = strings.TrimSpace(target)
|
||||
if target == "" {
|
||||
return models.GetResponse{}, errors.New("model identifier is required")
|
||||
}
|
||||
for _, item := range items {
|
||||
if strings.EqualFold(item.ModelID, target) {
|
||||
return item, nil
|
||||
}
|
||||
}
|
||||
matches := make([]models.GetResponse, 0, 4)
|
||||
for _, item := range items {
|
||||
if strings.EqualFold(item.Name, target) {
|
||||
matches = append(matches, item)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return models.GetResponse{}, fmt.Errorf("model %q not found", target)
|
||||
case 1:
|
||||
return matches[0], nil
|
||||
default:
|
||||
choices := make([]string, 0, len(matches))
|
||||
for _, item := range matches {
|
||||
choices = append(choices, fmt.Sprintf("%s/%s", h.resolveProviderName(cc, item.ProviderID), item.ModelID))
|
||||
}
|
||||
return models.GetResponse{}, fmt.Errorf("model %q is ambiguous; use a model ID or provider-qualified name (%s)", target, strings.Join(choices, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) filterModelsByProvider(cc CommandContext, items []models.GetResponse, providerName string) []models.GetResponse {
|
||||
providerName = strings.TrimSpace(providerName)
|
||||
if providerName == "" {
|
||||
return items
|
||||
}
|
||||
filtered := make([]models.GetResponse, 0, len(items))
|
||||
for _, item := range items {
|
||||
if strings.EqualFold(h.resolveProviderName(cc, item.ProviderID), providerName) {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func modelMarkers(modelID string, settingsResp settings.Settings) []string {
|
||||
var markers []string
|
||||
if modelID == settingsResp.ChatModelID {
|
||||
markers = append(markers, "chat")
|
||||
}
|
||||
if modelID == settingsResp.HeartbeatModelID {
|
||||
markers = append(markers, "heartbeat")
|
||||
}
|
||||
return markers
|
||||
}
|
||||
|
||||
func modelSortRank(model models.GetResponse, settingsResp settings.Settings) int {
|
||||
switch len(modelMarkers(model.ID, settingsResp)) {
|
||||
case 2:
|
||||
return 0
|
||||
case 1:
|
||||
return 1
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ func TestParse_Basic(t *testing.T) {
|
||||
args []string
|
||||
}{
|
||||
{"/help", "help", "", nil},
|
||||
{"/help model", "help", "model", nil},
|
||||
{"/help model set", "help", "model", []string{"set"}},
|
||||
{"/subagent list", "subagent", "list", nil},
|
||||
{"/subagent get mybot", "subagent", "get", []string{"mybot"}},
|
||||
{"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}},
|
||||
|
||||
@@ -8,10 +8,18 @@ import (
|
||||
|
||||
// CommandContext carries execution context for a sub-command.
|
||||
type CommandContext struct {
|
||||
Ctx context.Context
|
||||
BotID string
|
||||
Role string // "owner", "admin", "member", or "" (guest)
|
||||
Args []string
|
||||
Ctx context.Context
|
||||
BotID string
|
||||
Role string // "owner", "admin", "member", or "" (guest)
|
||||
Args []string
|
||||
ChannelIdentityID string
|
||||
UserID string
|
||||
ChannelType string
|
||||
ConversationType string
|
||||
ConversationID string
|
||||
ThreadID string
|
||||
RouteID string
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// SubCommand describes a single sub-command within a resource group.
|
||||
@@ -50,15 +58,38 @@ func (g *CommandGroup) Usage() string {
|
||||
fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description)
|
||||
for _, name := range g.order {
|
||||
sub := g.commands[name]
|
||||
perm := ""
|
||||
desc := subSummary(sub)
|
||||
if sub.IsWrite {
|
||||
perm = " [owner]"
|
||||
desc += " [owner]"
|
||||
}
|
||||
fmt.Fprintf(&b, "- %s%s\n", sub.Usage, perm)
|
||||
fmt.Fprintf(&b, "- %s\n", desc)
|
||||
}
|
||||
fmt.Fprintf(&b, "\nUse /help %s <action> for details.", g.Name)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (g *CommandGroup) ActionHelp(action string) string {
|
||||
sub, ok := g.commands[action]
|
||||
if !ok {
|
||||
return fmt.Sprintf("Unknown action %q for /%s.\n\n%s", action, g.Name, g.Usage())
|
||||
}
|
||||
usage, summary := splitUsage(sub.Usage)
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "/%s %s\n", g.Name, sub.Name)
|
||||
if summary != "" {
|
||||
fmt.Fprintf(&b, "- Summary: %s\n", summary)
|
||||
}
|
||||
if usage == "" {
|
||||
usage = sub.Name
|
||||
}
|
||||
fmt.Fprintf(&b, "- Usage: /%s %s\n", g.Name, usage)
|
||||
if sub.IsWrite {
|
||||
b.WriteString("- Access: owner only\n")
|
||||
}
|
||||
fmt.Fprintf(&b, "- Tip: use /help %s to view sibling actions.", g.Name)
|
||||
return strings.TrimRight(b.String(), "\n")
|
||||
}
|
||||
|
||||
// Registry holds all registered command groups.
|
||||
type Registry struct {
|
||||
groups map[string]*CommandGroup
|
||||
@@ -83,12 +114,47 @@ func (r *Registry) GlobalHelp() string {
|
||||
b.WriteString("/help - Show this help message\n")
|
||||
b.WriteString("/new - Start a new conversation (resets session context)\n")
|
||||
b.WriteString("/stop - Stop the current generation\n\n")
|
||||
for i, name := range r.order {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
for _, name := range r.order {
|
||||
group := r.groups[name]
|
||||
b.WriteString(group.Usage())
|
||||
fmt.Fprintf(&b, "- /%s - %s\n", group.Name, group.Description)
|
||||
}
|
||||
b.WriteString("\nUse /help <group> to view actions, e.g. /help model")
|
||||
return strings.TrimRight(b.String(), "\n")
|
||||
}
|
||||
|
||||
func (r *Registry) GroupHelp(name string) string {
|
||||
group, ok := r.groups[name]
|
||||
if !ok {
|
||||
return fmt.Sprintf("Unknown command group: /%s\n\n%s", name, r.GlobalHelp())
|
||||
}
|
||||
return group.Usage()
|
||||
}
|
||||
|
||||
func (r *Registry) ActionHelp(groupName, action string) string {
|
||||
group, ok := r.groups[groupName]
|
||||
if !ok {
|
||||
return fmt.Sprintf("Unknown command group: /%s\n\n%s", groupName, r.GlobalHelp())
|
||||
}
|
||||
return group.ActionHelp(action)
|
||||
}
|
||||
|
||||
func splitUsage(usage string) (commandUsage string, summary string) {
|
||||
usage = strings.TrimSpace(usage)
|
||||
if usage == "" {
|
||||
return "", ""
|
||||
}
|
||||
parts := strings.SplitN(usage, " - ", 2)
|
||||
commandUsage = strings.TrimSpace(parts[0])
|
||||
if len(parts) > 1 {
|
||||
summary = strings.TrimSpace(parts[1])
|
||||
}
|
||||
return commandUsage, summary
|
||||
}
|
||||
|
||||
func subSummary(sub SubCommand) string {
|
||||
usage, summary := splitUsage(sub.Usage)
|
||||
if summary == "" {
|
||||
return usage
|
||||
}
|
||||
return fmt.Sprintf("%s - %s", sub.Name, summary)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,9 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
Name: "list",
|
||||
Usage: "list - List all search providers",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.searchProvService == nil {
|
||||
return "Search provider service is not available.", nil
|
||||
}
|
||||
items, err := h.searchProvService.List(cc.Ctx, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -20,14 +23,40 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
if len(items) == 0 {
|
||||
return "No search providers found.", nil
|
||||
}
|
||||
records := make([][]kv, 0, len(items))
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
currentRecords := make([][]kv, 0, 1)
|
||||
otherRecords := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
records = append(records, []kv{
|
||||
{"Name", item.Name},
|
||||
label := item.Name
|
||||
record := []kv{
|
||||
{"Name", label},
|
||||
{"Provider", item.Provider},
|
||||
})
|
||||
}
|
||||
if item.ID == settingsResp.SearchProviderID {
|
||||
label += " [current]"
|
||||
record[0].value = label
|
||||
currentRecords = append(currentRecords, record)
|
||||
continue
|
||||
}
|
||||
otherRecords = append(otherRecords, record)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
currentRecords = append(currentRecords, otherRecords...)
|
||||
records := currentRecords
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /search current to inspect the active provider."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show the current search provider",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{{"Search Provider", h.resolveSearchProviderName(cc, settingsResp.SearchProviderID)}}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -38,7 +67,11 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /search set <name>", nil
|
||||
}
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
name := cc.Args[0]
|
||||
before, _ := h.getBotSettings(cc)
|
||||
items, err := h.searchProvService.List(cc.Ctx, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -51,7 +84,7 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Search provider set to %q.", item.Name), nil
|
||||
return formatChangedValue("Search provider", h.resolveSearchProviderName(cc, before.SearchProviderID), item.Name), nil
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Search provider %q not found.", name), nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -106,6 +107,13 @@ func settingsUpdateUsage() string {
|
||||
"- --heartbeat_model_id <id>"
|
||||
}
|
||||
|
||||
func (h *Handler) getBotSettings(cc CommandContext) (settings.Settings, error) {
|
||||
if h.settingsService == nil {
|
||||
return settings.Settings{}, errors.New("settings service is not available")
|
||||
}
|
||||
return h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||
}
|
||||
|
||||
// resolveModelName resolves a model UUID to "model_name (provider_name)".
|
||||
func (h *Handler) resolveModelName(cc CommandContext, modelID string) string {
|
||||
if modelID == "" {
|
||||
|
||||
+94
-57
@@ -6,7 +6,9 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
func (h *Handler) buildStatusGroup() *CommandGroup {
|
||||
@@ -16,78 +18,113 @@ func (h *Handler) buildStatusGroup() *CommandGroup {
|
||||
Name: "show",
|
||||
Usage: "show - Show current session status",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if strings.TrimSpace(cc.SessionID) == "" {
|
||||
return "No active session found for this conversation.", nil
|
||||
}
|
||||
return h.renderSessionStatus(cc, cc.SessionID, "current conversation")
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "latest",
|
||||
Usage: "latest - Show the latest session status for this bot",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Session info is not available.", nil
|
||||
}
|
||||
botUUID, err := parseBotUUID(cc.BotID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "No active session found.", nil
|
||||
return "No session found for this bot.", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, sessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("count messages: %w", err)
|
||||
}
|
||||
|
||||
var usedTokens int64
|
||||
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, sessionID)
|
||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("get usage: %w", err)
|
||||
}
|
||||
if err == nil {
|
||||
usedTokens = latestUsage
|
||||
}
|
||||
|
||||
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, sessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get cache: %w", err)
|
||||
}
|
||||
|
||||
var contextWindowStr string
|
||||
if h.settingsService != nil {
|
||||
s, sErr := h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||
if sErr == nil && s.ChatModelID != "" && h.modelsService != nil {
|
||||
m, mErr := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
|
||||
if mErr == nil && m.Config.ContextWindow != nil {
|
||||
contextWindowStr = formatTokens(int64(*m.Config.ContextWindow))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cacheHitRate float64
|
||||
if cacheRow.TotalInputTokens > 0 {
|
||||
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
|
||||
}
|
||||
|
||||
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, sessionID)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("Session Status:\n\n")
|
||||
fmt.Fprintf(&b, "- Messages: %d\n", msgCount)
|
||||
if contextWindowStr != "" {
|
||||
fmt.Fprintf(&b, "- Context: %s / %s\n", formatTokens(usedTokens), contextWindowStr)
|
||||
} else {
|
||||
fmt.Fprintf(&b, "- Context: %s\n", formatTokens(usedTokens))
|
||||
}
|
||||
fmt.Fprintf(&b, "- Cache Hit Rate: %.1f%%\n", cacheHitRate)
|
||||
fmt.Fprintf(&b, "- Cache Read: %s\n", formatTokens(cacheRow.CacheReadTokens))
|
||||
fmt.Fprintf(&b, "- Cache Write: %s\n", formatTokens(cacheRow.CacheWriteTokens))
|
||||
|
||||
if len(skills) > 0 {
|
||||
fmt.Fprintf(&b, "- Skills: %s\n", strings.Join(skills, ", "))
|
||||
}
|
||||
|
||||
return strings.TrimRight(b.String(), "\n"), nil
|
||||
return h.renderSessionStatus(cc, sessionID.String(), "latest bot session")
|
||||
},
|
||||
})
|
||||
return g
|
||||
}
|
||||
|
||||
func (h *Handler) renderSessionStatus(cc CommandContext, sessionID string, scope string) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Session info is not available.", nil
|
||||
}
|
||||
pgSessionID, err := parseCommandUUID(sessionID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, pgSessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("count messages: %w", err)
|
||||
}
|
||||
|
||||
var usedTokens int64
|
||||
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, pgSessionID)
|
||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("get usage: %w", err)
|
||||
}
|
||||
if err == nil {
|
||||
usedTokens = latestUsage
|
||||
}
|
||||
|
||||
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, pgSessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get cache: %w", err)
|
||||
}
|
||||
|
||||
var cacheHitRate float64
|
||||
if cacheRow.TotalInputTokens > 0 {
|
||||
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
|
||||
}
|
||||
|
||||
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, pgSessionID)
|
||||
|
||||
contextUsage := formatTokens(usedTokens)
|
||||
if contextWindow := h.resolveContextWindow(cc); contextWindow != "" {
|
||||
contextUsage = contextUsage + " / " + contextWindow
|
||||
}
|
||||
|
||||
pairs := []kv{
|
||||
{"Scope", scope},
|
||||
{"Session ID", sessionID},
|
||||
{"Messages", strconv.FormatInt(msgCount, 10)},
|
||||
{"Context", contextUsage},
|
||||
{"Cache Hit Rate", fmt.Sprintf("%.1f%%", cacheHitRate)},
|
||||
{"Cache Read", formatTokens(cacheRow.CacheReadTokens)},
|
||||
{"Cache Write", formatTokens(cacheRow.CacheWriteTokens)},
|
||||
}
|
||||
if len(skills) > 0 {
|
||||
pairs = append(pairs, kv{"Skills", strings.Join(skills, ", ")})
|
||||
}
|
||||
return formatKV(pairs), nil
|
||||
}
|
||||
|
||||
func (h *Handler) resolveContextWindow(cc CommandContext) string {
|
||||
if h.settingsService == nil || h.modelsService == nil {
|
||||
return ""
|
||||
}
|
||||
s, err := h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||
if err != nil || s.ChatModelID == "" {
|
||||
return ""
|
||||
}
|
||||
m, err := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
|
||||
if err != nil || m.Config.ContextWindow == nil {
|
||||
return ""
|
||||
}
|
||||
return formatTokens(int64(*m.Config.ContextWindow))
|
||||
}
|
||||
|
||||
func parseCommandUUID(id string) (pgtype.UUID, error) {
|
||||
parsed, err := uuid.Parse(strings.TrimSpace(id))
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid uuid: %w", err)
|
||||
}
|
||||
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
|
||||
}
|
||||
|
||||
func formatTokens(n int64) string {
|
||||
if n >= 1_000_000 {
|
||||
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
|
||||
|
||||
@@ -18,6 +18,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
|
||||
Name: "summary",
|
||||
Usage: "summary - Token usage summary (last 7 days)",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Usage info is not available.", nil
|
||||
}
|
||||
botUUID, err := parseBotUUID(cc.BotID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -89,6 +92,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
|
||||
Name: "by-model",
|
||||
Usage: "by-model - Token usage grouped by model",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Usage info is not available.", nil
|
||||
}
|
||||
botUUID, err := parseBotUUID(cc.BotID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
Reference in New Issue
Block a user