mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(command): improve slash command UX
Make slash commands easier to navigate in chat by splitting help into levels, compacting list output, and surfacing current selections for model, search, memory, and browser settings. Also route /status to the active conversation session and add an access inspector so users can understand their current command and ACL context.
This commit is contained in:
@@ -604,6 +604,7 @@ func provideChannelRouter(
|
||||
emailOutboxService,
|
||||
heartbeatService,
|
||||
queries,
|
||||
aclService,
|
||||
&commandSkillLoaderAdapter{handler: containerdHandler},
|
||||
&commandContainerFSAdapter{manager: manager},
|
||||
))
|
||||
@@ -779,6 +780,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
|
||||
if err != nil {
|
||||
return inbound.SessionResult{}, err
|
||||
}
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
||||
if err != nil {
|
||||
|
||||
@@ -492,6 +492,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc
|
||||
emailOutboxService,
|
||||
heartbeatService,
|
||||
queries,
|
||||
aclService,
|
||||
&commandSkillLoaderAdapter{handler: containerdHandler},
|
||||
&commandContainerFSAdapter{manager: manager},
|
||||
))
|
||||
@@ -877,6 +878,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
|
||||
if err != nil {
|
||||
return inbound.SessionResult{}, err
|
||||
}
|
||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||
}
|
||||
|
||||
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
||||
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
||||
if err != nil {
|
||||
|
||||
@@ -71,6 +71,7 @@ type ttsModelResolver interface {
|
||||
// SessionEnsurer resolves or creates an active session for a route.
|
||||
type SessionEnsurer interface {
|
||||
EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error)
|
||||
GetActiveSession(ctx context.Context, routeID string) (SessionResult, error)
|
||||
// CreateNewSession always creates a fresh session and sets it as the
|
||||
// active session for the given route, replacing any previous one.
|
||||
// sessionType defaults to "chat" if empty.
|
||||
@@ -298,11 +299,23 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
if isStopCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
return p.handleStopCommand(ctx, cfg, msg, sender, identity)
|
||||
}
|
||||
if isStatusCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
return p.handleStatusCommand(ctx, cfg, msg, sender, identity)
|
||||
}
|
||||
|
||||
// Skip generic command handler for mode-prefix commands (/btw, /now, /next)
|
||||
// so they pass through to mode detection below.
|
||||
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
|
||||
reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText)
|
||||
reply, err := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
|
||||
BotID: strings.TrimSpace(identity.BotID),
|
||||
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
|
||||
UserID: strings.TrimSpace(identity.UserID),
|
||||
Text: cmdText,
|
||||
ChannelType: msg.Channel.String(),
|
||||
ConversationType: strings.TrimSpace(msg.Conversation.Type),
|
||||
ConversationID: strings.TrimSpace(msg.Conversation.ID),
|
||||
ThreadID: extractThreadID(msg),
|
||||
})
|
||||
if err != nil {
|
||||
reply = "Error: " + err.Error()
|
||||
}
|
||||
@@ -2489,6 +2502,18 @@ func isNewSessionCommand(cmdText string) bool {
|
||||
return parsed.Resource == "new"
|
||||
}
|
||||
|
||||
func isStatusCommand(cmdText string) bool {
|
||||
extracted := command.ExtractCommandText(cmdText)
|
||||
if extracted == "" {
|
||||
return false
|
||||
}
|
||||
parsed, err := command.Parse(extracted)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return parsed.Resource == "status"
|
||||
}
|
||||
|
||||
// resolveNewSessionType determines the session type for /new command.
|
||||
// /new chat → chat, /new discuss → discuss, /new (no arg) → default by context.
|
||||
// WebUI (local channel) always defaults to chat.
|
||||
@@ -2611,3 +2636,83 @@ func (p *ChannelInboundProcessor) handleNewSessionCommand(
|
||||
Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)},
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ChannelInboundProcessor) handleStatusCommand(
|
||||
ctx context.Context,
|
||||
cfg channel.ChannelConfig,
|
||||
msg channel.InboundMessage,
|
||||
sender channel.StreamReplySender,
|
||||
identity InboundIdentity,
|
||||
) error {
|
||||
target := strings.TrimSpace(msg.ReplyTarget)
|
||||
if target == "" {
|
||||
return errors.New("reply target missing for /status command")
|
||||
}
|
||||
if p.routeResolver == nil {
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: "Error: route resolver not configured."},
|
||||
})
|
||||
}
|
||||
if p.commandHandler == nil {
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: "Error: command handler not configured."},
|
||||
})
|
||||
}
|
||||
|
||||
threadID := extractThreadID(msg)
|
||||
routeMetadata := buildRouteMetadata(msg, identity)
|
||||
p.enrichConversationAvatar(ctx, cfg, msg, routeMetadata)
|
||||
resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{
|
||||
BotID: identity.BotID,
|
||||
Platform: msg.Channel.String(),
|
||||
ConversationID: msg.Conversation.ID,
|
||||
ThreadID: threadID,
|
||||
ConversationType: msg.Conversation.Type,
|
||||
ChannelIdentityID: identity.UserID,
|
||||
ChannelConfigID: identity.ChannelConfigID,
|
||||
ReplyTarget: target,
|
||||
Metadata: routeMetadata,
|
||||
})
|
||||
if err != nil {
|
||||
if p.logger != nil {
|
||||
p.logger.Warn("resolve route for /status command failed", slog.Any("error", err))
|
||||
}
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: "Error: failed to resolve conversation route."},
|
||||
})
|
||||
}
|
||||
|
||||
sessionID := ""
|
||||
if p.sessionEnsurer != nil {
|
||||
sess, sessErr := p.sessionEnsurer.GetActiveSession(ctx, resolved.RouteID)
|
||||
if sessErr == nil {
|
||||
sessionID = strings.TrimSpace(sess.ID)
|
||||
} else if p.logger != nil {
|
||||
p.logger.Debug("resolve active session for /status command failed", slog.Any("error", sessErr))
|
||||
}
|
||||
}
|
||||
|
||||
reply, execErr := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
|
||||
BotID: strings.TrimSpace(identity.BotID),
|
||||
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
|
||||
UserID: strings.TrimSpace(identity.UserID),
|
||||
Text: rawTextForCommand(msg, strings.TrimSpace(msg.Message.PlainText())),
|
||||
ChannelType: msg.Channel.String(),
|
||||
ConversationType: strings.TrimSpace(msg.Conversation.Type),
|
||||
ConversationID: strings.TrimSpace(msg.Conversation.ID),
|
||||
ThreadID: threadID,
|
||||
RouteID: strings.TrimSpace(resolved.RouteID),
|
||||
SessionID: sessionID,
|
||||
})
|
||||
if execErr != nil {
|
||||
reply = "Error: " + execErr.Error()
|
||||
}
|
||||
|
||||
return sender.Send(ctx, channel.OutboundMessage{
|
||||
Target: target,
|
||||
Message: channel.Message{Text: reply},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,11 +13,15 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/acl"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/channel/identities"
|
||||
"github.com/memohai/memoh/internal/channel/route"
|
||||
"github.com/memohai/memoh/internal/command"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/media"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
@@ -185,6 +189,71 @@ type fakeChatACL struct {
|
||||
lastReq acl.EvaluateRequest
|
||||
}
|
||||
|
||||
type fakeSessionEnsurer struct {
|
||||
activeSession SessionResult
|
||||
activeErr error
|
||||
lastRouteID string
|
||||
}
|
||||
|
||||
func (f *fakeSessionEnsurer) EnsureActiveSession(_ context.Context, _, routeID, _ string) (SessionResult, error) {
|
||||
f.lastRouteID = routeID
|
||||
if f.activeErr != nil {
|
||||
return SessionResult{}, f.activeErr
|
||||
}
|
||||
return f.activeSession, nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionEnsurer) GetActiveSession(_ context.Context, routeID string) (SessionResult, error) {
|
||||
f.lastRouteID = routeID
|
||||
if f.activeErr != nil {
|
||||
return SessionResult{}, f.activeErr
|
||||
}
|
||||
return f.activeSession, nil
|
||||
}
|
||||
|
||||
func (f *fakeSessionEnsurer) CreateNewSession(_ context.Context, _, routeID, _, _ string) (SessionResult, error) {
|
||||
f.lastRouteID = routeID
|
||||
if f.activeErr != nil {
|
||||
return SessionResult{}, f.activeErr
|
||||
}
|
||||
return f.activeSession, nil
|
||||
}
|
||||
|
||||
type fakeCommandQueries struct {
|
||||
messageCount int64
|
||||
usage int64
|
||||
cacheRow dbsqlc.GetSessionCacheStatsRow
|
||||
skills []string
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
|
||||
return pgtype.UUID{}, errors.New("unexpected latest session lookup")
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
return f.messageCount, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
return f.usage, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
|
||||
return f.cacheRow, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
|
||||
return f.skills, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) {
|
||||
f.calls++
|
||||
f.lastReq = req
|
||||
@@ -548,6 +617,72 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelInboundProcessorStatusUsesRouteSession(t *testing.T) {
|
||||
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-status"}}
|
||||
policySvc := &fakePolicyService{}
|
||||
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-status", RouteID: "route-status"}}
|
||||
gateway := &fakeChatGateway{}
|
||||
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
|
||||
processor.SetSessionEnsurer(&fakeSessionEnsurer{
|
||||
activeSession: SessionResult{ID: "11111111-1111-1111-1111-111111111111", Type: "chat"},
|
||||
})
|
||||
processor.SetCommandHandler(command.NewHandler(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&fakeCommandQueries{
|
||||
messageCount: 9,
|
||||
usage: 512,
|
||||
cacheRow: dbsqlc.GetSessionCacheStatsRow{
|
||||
CacheReadTokens: 64,
|
||||
CacheWriteTokens: 32,
|
||||
TotalInputTokens: 512,
|
||||
},
|
||||
skills: []string{"search"},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
))
|
||||
sender := &fakeReplySender{}
|
||||
|
||||
cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("discord")}
|
||||
msg := channel.InboundMessage{
|
||||
BotID: "bot-1",
|
||||
Channel: channel.ChannelType("discord"),
|
||||
Message: channel.Message{Text: "/status"},
|
||||
ReplyTarget: "discord:status",
|
||||
Sender: channel.Identity{SubjectID: "user-1"},
|
||||
Conversation: channel.Conversation{
|
||||
ID: "conv-status",
|
||||
Type: channel.ConversationTypePrivate,
|
||||
},
|
||||
}
|
||||
|
||||
if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(sender.sent) != 1 {
|
||||
t.Fatalf("expected one status reply, got %d", len(sender.sent))
|
||||
}
|
||||
if !strings.Contains(sender.sent[0].Message.Text, "- Scope: current conversation") {
|
||||
t.Fatalf("expected current conversation scope, got %q", sender.sent[0].Message.Text)
|
||||
}
|
||||
if !strings.Contains(sender.sent[0].Message.Text, "- Session ID: 11111111-1111-1111-1111-111111111111") {
|
||||
t.Fatalf("expected active route session in reply, got %q", sender.sent[0].Message.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/acl"
|
||||
)
|
||||
|
||||
func (h *Handler) buildAccessGroup() *CommandGroup {
|
||||
g := newCommandGroup("access", "Inspect identity and permission context")
|
||||
g.DefaultAction = "show"
|
||||
g.Register(SubCommand{
|
||||
Name: "show",
|
||||
Usage: "show - Show current identity, write access, and chat ACL context",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
writeAccess := "no"
|
||||
if cc.Role == "owner" {
|
||||
writeAccess = "yes"
|
||||
}
|
||||
|
||||
pairs := []kv{
|
||||
{"Channel Identity", fallbackValue(cc.ChannelIdentityID)},
|
||||
{"Linked User", fallbackValue(cc.UserID)},
|
||||
{"Bot Role", fallbackValue(cc.Role)},
|
||||
{"Write Commands", writeAccess},
|
||||
{"Channel", fallbackValue(cc.ChannelType)},
|
||||
{"Conversation Type", fallbackValue(cc.ConversationType)},
|
||||
{"Conversation ID", fallbackValue(cc.ConversationID)},
|
||||
{"Thread ID", fallbackValue(cc.ThreadID)},
|
||||
}
|
||||
if strings.TrimSpace(cc.RouteID) != "" {
|
||||
pairs = append(pairs, kv{"Route ID", cc.RouteID})
|
||||
}
|
||||
if strings.TrimSpace(cc.SessionID) != "" {
|
||||
pairs = append(pairs, kv{"Session ID", cc.SessionID})
|
||||
}
|
||||
|
||||
aclStatus := "unavailable"
|
||||
if h.aclEvaluator != nil && strings.TrimSpace(cc.ChannelType) != "" {
|
||||
allowed, err := h.aclEvaluator.Evaluate(cc.Ctx, acl.EvaluateRequest{
|
||||
BotID: cc.BotID,
|
||||
ChannelIdentityID: cc.ChannelIdentityID,
|
||||
ChannelType: cc.ChannelType,
|
||||
SourceScope: acl.SourceScope{
|
||||
ConversationType: cc.ConversationType,
|
||||
ConversationID: cc.ConversationID,
|
||||
ThreadID: cc.ThreadID,
|
||||
},
|
||||
})
|
||||
switch {
|
||||
case err != nil:
|
||||
aclStatus = "error: " + err.Error()
|
||||
case allowed:
|
||||
aclStatus = "allow"
|
||||
default:
|
||||
aclStatus = "deny"
|
||||
}
|
||||
}
|
||||
pairs = append(pairs, kv{"Chat ACL", aclStatus})
|
||||
|
||||
return formatKV(pairs), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
}
|
||||
|
||||
func fallbackValue(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "(none)"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func formatChangedValue(label, before, after string) string {
|
||||
return fmt.Sprintf("%s: %s -> %s", label, fallbackValue(before), fallbackValue(after))
|
||||
}
|
||||
@@ -13,6 +13,9 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
Name: "list",
|
||||
Usage: "list - List all browser contexts",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.browserCtxService == nil {
|
||||
return "Browser context service is not available.", nil
|
||||
}
|
||||
items, err := h.browserCtxService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -20,13 +23,37 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
if len(items) == 0 {
|
||||
return "No browser contexts found.", nil
|
||||
}
|
||||
records := make([][]kv, 0, len(items))
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
currentRecords := make([][]kv, 0, 1)
|
||||
otherRecords := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
records = append(records, []kv{
|
||||
{"Name", item.Name},
|
||||
})
|
||||
label := item.Name
|
||||
record := []kv{{"Name", label}}
|
||||
if item.ID == settingsResp.BrowserContextID {
|
||||
label += " [current]"
|
||||
record[0].value = label
|
||||
currentRecords = append(currentRecords, record)
|
||||
continue
|
||||
}
|
||||
otherRecords = append(otherRecords, record)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
currentRecords = append(currentRecords, otherRecords...)
|
||||
records := currentRecords
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /browser current to inspect the active context."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show the current browser context",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{{"Browser Context", h.resolveBrowserContextName(cc, settingsResp.BrowserContextID)}}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -37,7 +64,11 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /browser set <name>", nil
|
||||
}
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
name := cc.Args[0]
|
||||
before, _ := h.getBotSettings(cc)
|
||||
items, err := h.browserCtxService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -50,7 +81,7 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Browser context set to %q.", item.Name), nil
|
||||
return formatChangedValue("Browser context", h.resolveBrowserContextName(cc, before.BrowserContextID), item.Name), nil
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Browser context %q not found.", name), nil
|
||||
|
||||
@@ -16,5 +16,6 @@ func (h *Handler) buildRegistry() *Registry {
|
||||
r.RegisterGroup(h.buildSkillGroup())
|
||||
r.RegisterGroup(h.buildFSGroup())
|
||||
r.RegisterGroup(h.buildStatusGroup())
|
||||
r.RegisterGroup(h.buildAccessGroup())
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
||||
{"Provider", item.Provider},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /email bindings to inspect bot bindings."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -46,7 +46,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
||||
{"Permissions", perms},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /email outbox to inspect recent sends."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -70,7 +70,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
||||
{"Sent", item.SentAt.Format("01-02 15:04")},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, 10, "Use the Web UI for older outbox entries."), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
|
||||
@@ -6,19 +6,10 @@ import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const defaultListLimit = 12
|
||||
|
||||
// formatItems renders a list of records as a Markdown-style list.
|
||||
// Each record is a slice of kv pairs; the first pair's value is used as the
|
||||
// bullet title, and subsequent pairs are indented beneath it.
|
||||
//
|
||||
// Example output:
|
||||
//
|
||||
// - mybot
|
||||
// Description: A helpful assistant
|
||||
// ID: abc123
|
||||
//
|
||||
// - another
|
||||
// Description: Something else
|
||||
// ID: def456
|
||||
// Each record is rendered on a single line so long lists stay readable in IM.
|
||||
func formatItems(items [][]kv) string {
|
||||
if len(items) == 0 {
|
||||
return ""
|
||||
@@ -31,14 +22,41 @@ func formatItems(items [][]kv) string {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
fmt.Fprintf(&b, "- %s\n", record[0].value)
|
||||
fmt.Fprintf(&b, "- %s", record[0].value)
|
||||
extras := make([]string, 0, len(record)-1)
|
||||
for _, pair := range record[1:] {
|
||||
fmt.Fprintf(&b, " %s: %s\n", pair.key, pair.value)
|
||||
if strings.TrimSpace(pair.value) == "" {
|
||||
continue
|
||||
}
|
||||
extras = append(extras, fmt.Sprintf("%s: %s", pair.key, pair.value))
|
||||
}
|
||||
if len(extras) > 0 {
|
||||
fmt.Fprintf(&b, " | %s", strings.Join(extras, " | "))
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatLimitedItems(items [][]kv, limit int, hint string) string {
|
||||
if len(items) == 0 {
|
||||
return ""
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = defaultListLimit
|
||||
}
|
||||
total := len(items)
|
||||
if total <= limit {
|
||||
return formatItems(items)
|
||||
}
|
||||
shown := items[:limit]
|
||||
result := formatItems(shown)
|
||||
suffix := fmt.Sprintf("Showing %d of %d items.", len(shown), total)
|
||||
if strings.TrimSpace(hint) != "" {
|
||||
suffix += " " + strings.TrimSpace(hint)
|
||||
}
|
||||
return result + "\n\n" + suffix
|
||||
}
|
||||
|
||||
// formatKV renders key-value pairs as a simple Markdown list.
|
||||
//
|
||||
// Example output:
|
||||
|
||||
+57
-13
@@ -4,10 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/browsercontexts"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
emailpkg "github.com/memohai/memoh/internal/email"
|
||||
"github.com/memohai/memoh/internal/heartbeat"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
@@ -56,13 +56,28 @@ type Handler struct {
|
||||
emailService *emailpkg.Service
|
||||
emailOutboxService *emailpkg.OutboxService
|
||||
heartbeatService *heartbeat.Service
|
||||
queries *dbsqlc.Queries
|
||||
queries CommandQueries
|
||||
aclEvaluator AccessEvaluator
|
||||
skillLoader SkillLoader
|
||||
containerFS ContainerFS
|
||||
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// ExecuteInput carries the caller identity and channel context for command execution.
|
||||
type ExecuteInput struct {
|
||||
BotID string
|
||||
ChannelIdentityID string
|
||||
UserID string
|
||||
Text string
|
||||
ChannelType string
|
||||
ConversationType string
|
||||
ConversationID string
|
||||
ThreadID string
|
||||
RouteID string
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with all required services.
|
||||
func NewHandler(
|
||||
log *slog.Logger,
|
||||
@@ -78,7 +93,8 @@ func NewHandler(
|
||||
emailService *emailpkg.Service,
|
||||
emailOutboxService *emailpkg.OutboxService,
|
||||
heartbeatService *heartbeat.Service,
|
||||
queries *dbsqlc.Queries,
|
||||
queries CommandQueries,
|
||||
aclEvaluator AccessEvaluator,
|
||||
skillLoader SkillLoader,
|
||||
containerFS ContainerFS,
|
||||
) *Handler {
|
||||
@@ -99,6 +115,7 @@ func NewHandler(
|
||||
emailOutboxService: emailOutboxService,
|
||||
heartbeatService: heartbeatService,
|
||||
queries: queries,
|
||||
aclEvaluator: aclEvaluator,
|
||||
skillLoader: skillLoader,
|
||||
containerFS: containerFS,
|
||||
logger: log.With(slog.String("component", "command")),
|
||||
@@ -140,7 +157,16 @@ func (h *Handler) IsCommand(text string) bool {
|
||||
|
||||
// Execute parses and runs a slash command, returning the text reply.
|
||||
func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) {
|
||||
cmdText := ExtractCommandText(text)
|
||||
return h.ExecuteWithInput(ctx, ExecuteInput{
|
||||
BotID: botID,
|
||||
ChannelIdentityID: channelIdentityID,
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
// ExecuteWithInput parses and runs a slash command with channel/session context.
|
||||
func (h *Handler) ExecuteWithInput(ctx context.Context, input ExecuteInput) (string, error) {
|
||||
cmdText := ExtractCommandText(input.Text)
|
||||
if cmdText == "" {
|
||||
return h.registry.GlobalHelp(), nil
|
||||
}
|
||||
@@ -151,12 +177,16 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
|
||||
|
||||
// Resolve the user's role in this bot.
|
||||
role := ""
|
||||
if h.roleResolver != nil && channelIdentityID != "" {
|
||||
r, err := h.roleResolver.GetMemberRole(ctx, botID, channelIdentityID)
|
||||
roleIdentityID := input.ChannelIdentityID
|
||||
if strings.TrimSpace(input.UserID) != "" {
|
||||
roleIdentityID = strings.TrimSpace(input.UserID)
|
||||
}
|
||||
if h.roleResolver != nil && roleIdentityID != "" {
|
||||
r, err := h.roleResolver.GetMemberRole(ctx, input.BotID, roleIdentityID)
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to resolve member role",
|
||||
slog.String("bot_id", botID),
|
||||
slog.String("channel_identity_id", channelIdentityID),
|
||||
slog.String("bot_id", input.BotID),
|
||||
slog.String("role_identity_id", roleIdentityID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
} else {
|
||||
@@ -165,15 +195,29 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
|
||||
}
|
||||
|
||||
cc := CommandContext{
|
||||
Ctx: ctx,
|
||||
BotID: botID,
|
||||
Role: role,
|
||||
Args: parsed.Args,
|
||||
Ctx: ctx,
|
||||
BotID: input.BotID,
|
||||
Role: role,
|
||||
Args: parsed.Args,
|
||||
ChannelIdentityID: strings.TrimSpace(input.ChannelIdentityID),
|
||||
UserID: strings.TrimSpace(input.UserID),
|
||||
ChannelType: strings.TrimSpace(input.ChannelType),
|
||||
ConversationType: strings.TrimSpace(input.ConversationType),
|
||||
ConversationID: strings.TrimSpace(input.ConversationID),
|
||||
ThreadID: strings.TrimSpace(input.ThreadID),
|
||||
RouteID: strings.TrimSpace(input.RouteID),
|
||||
SessionID: strings.TrimSpace(input.SessionID),
|
||||
}
|
||||
|
||||
// /help
|
||||
if parsed.Resource == "help" {
|
||||
return h.registry.GlobalHelp(), nil
|
||||
if parsed.Action == "" {
|
||||
return h.registry.GlobalHelp(), nil
|
||||
}
|
||||
if len(parsed.Args) == 0 {
|
||||
return h.registry.GroupHelp(parsed.Action), nil
|
||||
}
|
||||
return h.registry.ActionHelp(parsed.Action, parsed.Args[0]), nil
|
||||
}
|
||||
|
||||
// Top-level commands (e.g. /new) are handled by the channel inbound
|
||||
|
||||
@@ -5,6 +5,10 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
@@ -25,9 +29,58 @@ type fakeScheduleService struct {
|
||||
items []schedule.Schedule
|
||||
}
|
||||
|
||||
type fakeCommandQueries struct {
|
||||
latestSessionID pgtype.UUID
|
||||
latestSessionErr error
|
||||
messageCount int64
|
||||
latestUsage int64
|
||||
latestUsageErr error
|
||||
cacheRow dbsqlc.GetSessionCacheStatsRow
|
||||
cacheErr error
|
||||
skills []string
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
|
||||
return f.latestSessionID, f.latestSessionErr
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
return f.messageCount, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||
if f.latestUsageErr != nil {
|
||||
return 0, f.latestUsageErr
|
||||
}
|
||||
return f.latestUsage, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
|
||||
if f.cacheErr != nil {
|
||||
return dbsqlc.GetSessionCacheStatsRow{}, f.cacheErr
|
||||
}
|
||||
return f.cacheRow, nil
|
||||
}
|
||||
|
||||
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
|
||||
return f.skills, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// newTestHandler creates a Handler with nil services for use in tests.
|
||||
func newTestHandler(roleResolver MemberRoleResolver) *Handler {
|
||||
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func newTestHandlerWithQueries(roleResolver MemberRoleResolver, queries CommandQueries) *Handler {
|
||||
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, queries, nil, nil, nil)
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
@@ -73,6 +126,42 @@ func TestExecute_Help(t *testing.T) {
|
||||
if !strings.Contains(result, "Available commands") {
|
||||
t.Errorf("expected help text, got: %s", result)
|
||||
}
|
||||
if strings.Contains(result, "set-heartbeat") {
|
||||
t.Errorf("top-level help should not expand nested actions, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- /model - Manage bot models") {
|
||||
t.Errorf("expected top-level model entry, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_HelpGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "/model - Manage bot models") {
|
||||
t.Errorf("expected group help, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- set - Set the chat model [owner]") {
|
||||
t.Errorf("expected compact action summary, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_HelpAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model set")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "Usage: /model set <model_id> | <provider_name> <model_name>") {
|
||||
t.Errorf("expected action usage, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "Access: owner only") {
|
||||
t.Errorf("expected owner hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_UnknownCommand(t *testing.T) {
|
||||
@@ -207,8 +296,8 @@ func TestFormatItems(t *testing.T) {
|
||||
if !strings.Contains(result, "- foo") {
|
||||
t.Errorf("expected '- foo' bullet, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, " Type: bar") {
|
||||
t.Errorf("expected indented 'Type: bar', got: %s", result)
|
||||
if !strings.Contains(result, "- foo | Type: bar") {
|
||||
t.Errorf("expected compact line entry, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- longname") {
|
||||
t.Errorf("expected '- longname' bullet, got: %s", result)
|
||||
@@ -255,7 +344,7 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
|
||||
for _, group := range []string{
|
||||
"schedule", "mcp", "settings",
|
||||
"model", "memory", "search", "browser", "usage",
|
||||
"email", "heartbeat", "skill", "fs",
|
||||
"email", "heartbeat", "skill", "fs", "access",
|
||||
} {
|
||||
if !strings.Contains(help, "/"+group) {
|
||||
t.Errorf("missing /%s in global help", group)
|
||||
@@ -263,6 +352,86 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithInput_Access(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||
result, err := h.ExecuteWithInput(context.Background(), ExecuteInput{
|
||||
BotID: "bot-1",
|
||||
ChannelIdentityID: "channel-id-1",
|
||||
UserID: "user-id-1",
|
||||
Text: "/access",
|
||||
ChannelType: "discord",
|
||||
ConversationType: "thread",
|
||||
ConversationID: "conv-1",
|
||||
ThreadID: "thread-1",
|
||||
RouteID: "route-1",
|
||||
SessionID: "session-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "- Channel Identity: channel-id-1") {
|
||||
t.Errorf("expected channel identity in access output, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- Write Commands: yes") {
|
||||
t.Errorf("expected write access in access output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_StatusLatest(t *testing.T) {
|
||||
t.Parallel()
|
||||
sessionUUID := pgtype.UUID{}
|
||||
copy(sessionUUID.Bytes[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
sessionUUID.Valid = true
|
||||
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
|
||||
latestSessionID: sessionUUID,
|
||||
messageCount: 42,
|
||||
latestUsage: 1200,
|
||||
cacheRow: dbsqlc.GetSessionCacheStatsRow{
|
||||
CacheReadTokens: 300,
|
||||
CacheWriteTokens: 150,
|
||||
TotalInputTokens: 1200,
|
||||
},
|
||||
skills: []string{"search", "browser"},
|
||||
})
|
||||
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "- Scope: latest bot session") {
|
||||
t.Errorf("expected latest scope, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "- Messages: 42") {
|
||||
t.Errorf("expected message count, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_StatusLatestNoRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
|
||||
latestSessionErr: pgx.ErrNoRows,
|
||||
})
|
||||
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "No session found for this bot.") {
|
||||
t.Errorf("expected no session message, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute_StatusShowWithoutSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{})
|
||||
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/status")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(result, "No active session found for this conversation.") {
|
||||
t.Errorf("expected route-aware no session message, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify write commands are tagged with [owner] in usage.
|
||||
func TestUsage_OwnerTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h *Handler) buildHeartbeatGroup() *CommandGroup {
|
||||
}
|
||||
records = append(records, rec)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, 10, "Use the Web UI for older heartbeat logs."), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package command
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/acl"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
// Skill represents a single skill loaded from a bot's container.
|
||||
type Skill struct {
|
||||
@@ -25,3 +32,20 @@ type ContainerFS interface {
|
||||
ListDir(ctx context.Context, botID, path string) ([]FSEntry, error)
|
||||
ReadFile(ctx context.Context, botID, path string) (string, error)
|
||||
}
|
||||
|
||||
// CommandQueries captures the sqlc methods used by slash commands.
|
||||
// *dbsqlc.Queries satisfies this interface directly.
|
||||
type CommandQueries interface {
|
||||
GetLatestSessionIDByBot(ctx context.Context, botID pgtype.UUID) (pgtype.UUID, error)
|
||||
CountMessagesBySession(ctx context.Context, sessionID pgtype.UUID) (int64, error)
|
||||
GetLatestAssistantUsage(ctx context.Context, sessionID pgtype.UUID) (int64, error)
|
||||
GetSessionCacheStats(ctx context.Context, sessionID pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error)
|
||||
GetSessionUsedSkills(ctx context.Context, sessionID pgtype.UUID) ([]string, error)
|
||||
GetTokenUsageByDayAndType(ctx context.Context, arg dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error)
|
||||
GetTokenUsageByModel(ctx context.Context, arg dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error)
|
||||
}
|
||||
|
||||
// AccessEvaluator checks whether the current channel context may trigger chat.
|
||||
type AccessEvaluator interface {
|
||||
Evaluate(ctx context.Context, req acl.EvaluateRequest) (bool, error)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func (h *Handler) buildMCPGroup() *CommandGroup {
|
||||
{"Status", item.Status},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /mcp get <name> for full details."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
|
||||
@@ -13,6 +13,9 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
Name: "list",
|
||||
Usage: "list - List all memory providers",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.memProvService == nil {
|
||||
return "Memory provider service is not available.", nil
|
||||
}
|
||||
items, err := h.memProvService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -20,18 +23,44 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
if len(items) == 0 {
|
||||
return "No memory providers found.", nil
|
||||
}
|
||||
records := make([][]kv, 0, len(items))
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
currentRecords := make([][]kv, 0, 1)
|
||||
otherRecords := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
def := ""
|
||||
if item.IsDefault {
|
||||
def = " (default)"
|
||||
}
|
||||
records = append(records, []kv{
|
||||
{"Name", item.Name + def},
|
||||
label := item.Name + def
|
||||
record := []kv{
|
||||
{"Name", label},
|
||||
{"Provider", item.Provider},
|
||||
})
|
||||
}
|
||||
if item.ID == settingsResp.MemoryProviderID {
|
||||
label += " [current]"
|
||||
record[0].value = label
|
||||
currentRecords = append(currentRecords, record)
|
||||
continue
|
||||
}
|
||||
otherRecords = append(otherRecords, record)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
currentRecords = append(currentRecords, otherRecords...)
|
||||
records := currentRecords
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /memory current to inspect the active provider."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show the current memory provider",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{{"Memory Provider", h.resolveMemoryProviderName(cc, settingsResp.MemoryProviderID)}}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -42,7 +71,11 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /memory set <name>", nil
|
||||
}
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
name := cc.Args[0]
|
||||
before, _ := h.getBotSettings(cc)
|
||||
items, err := h.memProvService.List(cc.Ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -55,7 +88,7 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Memory provider set to %q.", item.Name), nil
|
||||
return formatChangedValue("Memory provider", h.resolveMemoryProviderName(cc, before.MemoryProviderID), item.Name), nil
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Memory provider %q not found.", name), nil
|
||||
|
||||
+147
-13
@@ -1,7 +1,9 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
@@ -12,36 +14,81 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
||||
g := newCommandGroup("model", "Manage bot models")
|
||||
g.Register(SubCommand{
|
||||
Name: "list",
|
||||
Usage: "list - List all available chat models",
|
||||
Usage: "list [provider_name] - List available chat models",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.modelsService == nil {
|
||||
return "Model service is not available.", nil
|
||||
}
|
||||
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
filterProvider := ""
|
||||
if len(cc.Args) > 0 {
|
||||
filterProvider = strings.TrimSpace(strings.Join(cc.Args, " "))
|
||||
}
|
||||
items = h.filterModelsByProvider(cc, items, filterProvider)
|
||||
if len(items) == 0 {
|
||||
if filterProvider != "" {
|
||||
return fmt.Sprintf("No chat models found for provider %q.", filterProvider), nil
|
||||
}
|
||||
return "No chat models found.", nil
|
||||
}
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
sort.SliceStable(items, func(i, j int) bool {
|
||||
return modelSortRank(items[i], settingsResp) < modelSortRank(items[j], settingsResp)
|
||||
})
|
||||
records := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
provName := h.resolveProviderName(cc, item.ProviderID)
|
||||
label := item.Name
|
||||
markers := modelMarkers(item.ID, settingsResp)
|
||||
if len(markers) > 0 {
|
||||
label += " [" + strings.Join(markers, ", ") + "]"
|
||||
}
|
||||
records = append(records, []kv{
|
||||
{"Model", item.Name},
|
||||
{"Model", label},
|
||||
{"Provider", provName},
|
||||
{"Model ID", item.ModelID},
|
||||
})
|
||||
}
|
||||
return formatItems(records), nil
|
||||
hint := "Use /model current to inspect active selections."
|
||||
if filterProvider == "" {
|
||||
hint = "Use /model list <provider_name> to narrow results."
|
||||
}
|
||||
return formatLimitedItems(records, defaultListLimit, hint), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show current chat and heartbeat models",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{
|
||||
{"Chat Model", h.resolveModelName(cc, settingsResp.ChatModelID)},
|
||||
{"Heartbeat Model", h.resolveModelName(cc, settingsResp.HeartbeatModelID)},
|
||||
}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "set",
|
||||
Usage: "set <provider_name> <model_name> - Set the chat model",
|
||||
Usage: "set <model_id> | <provider_name> <model_name> - Set the chat model",
|
||||
IsWrite: true,
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if len(cc.Args) < 2 {
|
||||
return "Usage: /model set <provider_name> <model_name>", nil
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /model set <model_id> | <provider_name> <model_name>", nil
|
||||
}
|
||||
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
before, _ := h.getBotSettings(cc)
|
||||
modelResp, err := h.findModelForSelection(cc, cc.Args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -51,18 +98,22 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Chat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
|
||||
return formatChangedValue("Chat model", h.resolveModelName(cc, before.ChatModelID), h.resolveModelName(cc, modelResp.ID)), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "set-heartbeat",
|
||||
Usage: "set-heartbeat <provider_name> <model_name> - Set the heartbeat model",
|
||||
Usage: "set-heartbeat <model_id> | <provider_name> <model_name> - Set the heartbeat model",
|
||||
IsWrite: true,
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if len(cc.Args) < 2 {
|
||||
return "Usage: /model set-heartbeat <provider_name> <model_name>", nil
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /model set-heartbeat <model_id> | <provider_name> <model_name>", nil
|
||||
}
|
||||
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
before, _ := h.getBotSettings(cc)
|
||||
modelResp, err := h.findModelForSelection(cc, cc.Args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -72,7 +123,7 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Heartbeat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
|
||||
return formatChangedValue("Heartbeat model", h.resolveModelName(cc, before.HeartbeatModelID), h.resolveModelName(cc, modelResp.ID)), nil
|
||||
},
|
||||
})
|
||||
return g
|
||||
@@ -105,3 +156,86 @@ func (h *Handler) findModelByProviderAndName(cc CommandContext, providerName, mo
|
||||
}
|
||||
return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName)
|
||||
}
|
||||
|
||||
func (h *Handler) findModelForSelection(cc CommandContext, args []string) (models.GetResponse, error) {
|
||||
if h.modelsService == nil {
|
||||
return models.GetResponse{}, errors.New("model service is not available")
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return models.GetResponse{}, errors.New("model identifier is required")
|
||||
}
|
||||
if len(args) == 1 {
|
||||
return h.findModelByIDOrName(cc, args[0])
|
||||
}
|
||||
return h.findModelByProviderAndName(cc, args[0], strings.Join(args[1:], " "))
|
||||
}
|
||||
|
||||
func (h *Handler) findModelByIDOrName(cc CommandContext, target string) (models.GetResponse, error) {
|
||||
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, err
|
||||
}
|
||||
target = strings.TrimSpace(target)
|
||||
if target == "" {
|
||||
return models.GetResponse{}, errors.New("model identifier is required")
|
||||
}
|
||||
for _, item := range items {
|
||||
if strings.EqualFold(item.ModelID, target) {
|
||||
return item, nil
|
||||
}
|
||||
}
|
||||
matches := make([]models.GetResponse, 0, 4)
|
||||
for _, item := range items {
|
||||
if strings.EqualFold(item.Name, target) {
|
||||
matches = append(matches, item)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return models.GetResponse{}, fmt.Errorf("model %q not found", target)
|
||||
case 1:
|
||||
return matches[0], nil
|
||||
default:
|
||||
choices := make([]string, 0, len(matches))
|
||||
for _, item := range matches {
|
||||
choices = append(choices, fmt.Sprintf("%s/%s", h.resolveProviderName(cc, item.ProviderID), item.ModelID))
|
||||
}
|
||||
return models.GetResponse{}, fmt.Errorf("model %q is ambiguous; use a model ID or provider-qualified name (%s)", target, strings.Join(choices, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) filterModelsByProvider(cc CommandContext, items []models.GetResponse, providerName string) []models.GetResponse {
|
||||
providerName = strings.TrimSpace(providerName)
|
||||
if providerName == "" {
|
||||
return items
|
||||
}
|
||||
filtered := make([]models.GetResponse, 0, len(items))
|
||||
for _, item := range items {
|
||||
if strings.EqualFold(h.resolveProviderName(cc, item.ProviderID), providerName) {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func modelMarkers(modelID string, settingsResp settings.Settings) []string {
|
||||
var markers []string
|
||||
if modelID == settingsResp.ChatModelID {
|
||||
markers = append(markers, "chat")
|
||||
}
|
||||
if modelID == settingsResp.HeartbeatModelID {
|
||||
markers = append(markers, "heartbeat")
|
||||
}
|
||||
return markers
|
||||
}
|
||||
|
||||
func modelSortRank(model models.GetResponse, settingsResp settings.Settings) int {
|
||||
switch len(modelMarkers(model.ID, settingsResp)) {
|
||||
case 2:
|
||||
return 0
|
||||
case 1:
|
||||
return 1
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ func TestParse_Basic(t *testing.T) {
|
||||
args []string
|
||||
}{
|
||||
{"/help", "help", "", nil},
|
||||
{"/help model", "help", "model", nil},
|
||||
{"/help model set", "help", "model", []string{"set"}},
|
||||
{"/subagent list", "subagent", "list", nil},
|
||||
{"/subagent get mybot", "subagent", "get", []string{"mybot"}},
|
||||
{"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}},
|
||||
|
||||
@@ -8,10 +8,18 @@ import (
|
||||
|
||||
// CommandContext carries execution context for a sub-command.
|
||||
type CommandContext struct {
|
||||
Ctx context.Context
|
||||
BotID string
|
||||
Role string // "owner", "admin", "member", or "" (guest)
|
||||
Args []string
|
||||
Ctx context.Context
|
||||
BotID string
|
||||
Role string // "owner", "admin", "member", or "" (guest)
|
||||
Args []string
|
||||
ChannelIdentityID string
|
||||
UserID string
|
||||
ChannelType string
|
||||
ConversationType string
|
||||
ConversationID string
|
||||
ThreadID string
|
||||
RouteID string
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// SubCommand describes a single sub-command within a resource group.
|
||||
@@ -50,15 +58,38 @@ func (g *CommandGroup) Usage() string {
|
||||
fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description)
|
||||
for _, name := range g.order {
|
||||
sub := g.commands[name]
|
||||
perm := ""
|
||||
desc := subSummary(sub)
|
||||
if sub.IsWrite {
|
||||
perm = " [owner]"
|
||||
desc += " [owner]"
|
||||
}
|
||||
fmt.Fprintf(&b, "- %s%s\n", sub.Usage, perm)
|
||||
fmt.Fprintf(&b, "- %s\n", desc)
|
||||
}
|
||||
fmt.Fprintf(&b, "\nUse /help %s <action> for details.", g.Name)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (g *CommandGroup) ActionHelp(action string) string {
|
||||
sub, ok := g.commands[action]
|
||||
if !ok {
|
||||
return fmt.Sprintf("Unknown action %q for /%s.\n\n%s", action, g.Name, g.Usage())
|
||||
}
|
||||
usage, summary := splitUsage(sub.Usage)
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "/%s %s\n", g.Name, sub.Name)
|
||||
if summary != "" {
|
||||
fmt.Fprintf(&b, "- Summary: %s\n", summary)
|
||||
}
|
||||
if usage == "" {
|
||||
usage = sub.Name
|
||||
}
|
||||
fmt.Fprintf(&b, "- Usage: /%s %s\n", g.Name, usage)
|
||||
if sub.IsWrite {
|
||||
b.WriteString("- Access: owner only\n")
|
||||
}
|
||||
fmt.Fprintf(&b, "- Tip: use /help %s to view sibling actions.", g.Name)
|
||||
return strings.TrimRight(b.String(), "\n")
|
||||
}
|
||||
|
||||
// Registry holds all registered command groups.
|
||||
type Registry struct {
|
||||
groups map[string]*CommandGroup
|
||||
@@ -83,12 +114,47 @@ func (r *Registry) GlobalHelp() string {
|
||||
b.WriteString("/help - Show this help message\n")
|
||||
b.WriteString("/new - Start a new conversation (resets session context)\n")
|
||||
b.WriteString("/stop - Stop the current generation\n\n")
|
||||
for i, name := range r.order {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
for _, name := range r.order {
|
||||
group := r.groups[name]
|
||||
b.WriteString(group.Usage())
|
||||
fmt.Fprintf(&b, "- /%s - %s\n", group.Name, group.Description)
|
||||
}
|
||||
b.WriteString("\nUse /help <group> to view actions, e.g. /help model")
|
||||
return strings.TrimRight(b.String(), "\n")
|
||||
}
|
||||
|
||||
func (r *Registry) GroupHelp(name string) string {
|
||||
group, ok := r.groups[name]
|
||||
if !ok {
|
||||
return fmt.Sprintf("Unknown command group: /%s\n\n%s", name, r.GlobalHelp())
|
||||
}
|
||||
return group.Usage()
|
||||
}
|
||||
|
||||
func (r *Registry) ActionHelp(groupName, action string) string {
|
||||
group, ok := r.groups[groupName]
|
||||
if !ok {
|
||||
return fmt.Sprintf("Unknown command group: /%s\n\n%s", groupName, r.GlobalHelp())
|
||||
}
|
||||
return group.ActionHelp(action)
|
||||
}
|
||||
|
||||
func splitUsage(usage string) (commandUsage string, summary string) {
|
||||
usage = strings.TrimSpace(usage)
|
||||
if usage == "" {
|
||||
return "", ""
|
||||
}
|
||||
parts := strings.SplitN(usage, " - ", 2)
|
||||
commandUsage = strings.TrimSpace(parts[0])
|
||||
if len(parts) > 1 {
|
||||
summary = strings.TrimSpace(parts[1])
|
||||
}
|
||||
return commandUsage, summary
|
||||
}
|
||||
|
||||
func subSummary(sub SubCommand) string {
|
||||
usage, summary := splitUsage(sub.Usage)
|
||||
if summary == "" {
|
||||
return usage
|
||||
}
|
||||
return fmt.Sprintf("%s - %s", sub.Name, summary)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,9 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
Name: "list",
|
||||
Usage: "list - List all search providers",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.searchProvService == nil {
|
||||
return "Search provider service is not available.", nil
|
||||
}
|
||||
items, err := h.searchProvService.List(cc.Ctx, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -20,14 +23,40 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
if len(items) == 0 {
|
||||
return "No search providers found.", nil
|
||||
}
|
||||
records := make([][]kv, 0, len(items))
|
||||
settingsResp, _ := h.getBotSettings(cc)
|
||||
currentRecords := make([][]kv, 0, 1)
|
||||
otherRecords := make([][]kv, 0, len(items))
|
||||
for _, item := range items {
|
||||
records = append(records, []kv{
|
||||
{"Name", item.Name},
|
||||
label := item.Name
|
||||
record := []kv{
|
||||
{"Name", label},
|
||||
{"Provider", item.Provider},
|
||||
})
|
||||
}
|
||||
if item.ID == settingsResp.SearchProviderID {
|
||||
label += " [current]"
|
||||
record[0].value = label
|
||||
currentRecords = append(currentRecords, record)
|
||||
continue
|
||||
}
|
||||
otherRecords = append(otherRecords, record)
|
||||
}
|
||||
return formatItems(records), nil
|
||||
currentRecords = append(currentRecords, otherRecords...)
|
||||
records := currentRecords
|
||||
return formatLimitedItems(records, defaultListLimit, "Use /search current to inspect the active provider."), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "current",
|
||||
Usage: "current - Show the current search provider",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
settingsResp, err := h.getBotSettings(cc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return formatKV([]kv{{"Search Provider", h.resolveSearchProviderName(cc, settingsResp.SearchProviderID)}}), nil
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
@@ -38,7 +67,11 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
if len(cc.Args) < 1 {
|
||||
return "Usage: /search set <name>", nil
|
||||
}
|
||||
if h.settingsService == nil {
|
||||
return "Settings service is not available.", nil
|
||||
}
|
||||
name := cc.Args[0]
|
||||
before, _ := h.getBotSettings(cc)
|
||||
items, err := h.searchProvService.List(cc.Ctx, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -51,7 +84,7 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Search provider set to %q.", item.Name), nil
|
||||
return formatChangedValue("Search provider", h.resolveSearchProviderName(cc, before.SearchProviderID), item.Name), nil
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("Search provider %q not found.", name), nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -106,6 +107,13 @@ func settingsUpdateUsage() string {
|
||||
"- --heartbeat_model_id <id>"
|
||||
}
|
||||
|
||||
func (h *Handler) getBotSettings(cc CommandContext) (settings.Settings, error) {
|
||||
if h.settingsService == nil {
|
||||
return settings.Settings{}, errors.New("settings service is not available")
|
||||
}
|
||||
return h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||
}
|
||||
|
||||
// resolveModelName resolves a model UUID to "model_name (provider_name)".
|
||||
func (h *Handler) resolveModelName(cc CommandContext, modelID string) string {
|
||||
if modelID == "" {
|
||||
|
||||
+94
-57
@@ -6,7 +6,9 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
func (h *Handler) buildStatusGroup() *CommandGroup {
|
||||
@@ -16,78 +18,113 @@ func (h *Handler) buildStatusGroup() *CommandGroup {
|
||||
Name: "show",
|
||||
Usage: "show - Show current session status",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if strings.TrimSpace(cc.SessionID) == "" {
|
||||
return "No active session found for this conversation.", nil
|
||||
}
|
||||
return h.renderSessionStatus(cc, cc.SessionID, "current conversation")
|
||||
},
|
||||
})
|
||||
g.Register(SubCommand{
|
||||
Name: "latest",
|
||||
Usage: "latest - Show the latest session status for this bot",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Session info is not available.", nil
|
||||
}
|
||||
botUUID, err := parseBotUUID(cc.BotID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "No active session found.", nil
|
||||
return "No session found for this bot.", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, sessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("count messages: %w", err)
|
||||
}
|
||||
|
||||
var usedTokens int64
|
||||
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, sessionID)
|
||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("get usage: %w", err)
|
||||
}
|
||||
if err == nil {
|
||||
usedTokens = latestUsage
|
||||
}
|
||||
|
||||
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, sessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get cache: %w", err)
|
||||
}
|
||||
|
||||
var contextWindowStr string
|
||||
if h.settingsService != nil {
|
||||
s, sErr := h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||
if sErr == nil && s.ChatModelID != "" && h.modelsService != nil {
|
||||
m, mErr := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
|
||||
if mErr == nil && m.Config.ContextWindow != nil {
|
||||
contextWindowStr = formatTokens(int64(*m.Config.ContextWindow))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cacheHitRate float64
|
||||
if cacheRow.TotalInputTokens > 0 {
|
||||
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
|
||||
}
|
||||
|
||||
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, sessionID)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("Session Status:\n\n")
|
||||
fmt.Fprintf(&b, "- Messages: %d\n", msgCount)
|
||||
if contextWindowStr != "" {
|
||||
fmt.Fprintf(&b, "- Context: %s / %s\n", formatTokens(usedTokens), contextWindowStr)
|
||||
} else {
|
||||
fmt.Fprintf(&b, "- Context: %s\n", formatTokens(usedTokens))
|
||||
}
|
||||
fmt.Fprintf(&b, "- Cache Hit Rate: %.1f%%\n", cacheHitRate)
|
||||
fmt.Fprintf(&b, "- Cache Read: %s\n", formatTokens(cacheRow.CacheReadTokens))
|
||||
fmt.Fprintf(&b, "- Cache Write: %s\n", formatTokens(cacheRow.CacheWriteTokens))
|
||||
|
||||
if len(skills) > 0 {
|
||||
fmt.Fprintf(&b, "- Skills: %s\n", strings.Join(skills, ", "))
|
||||
}
|
||||
|
||||
return strings.TrimRight(b.String(), "\n"), nil
|
||||
return h.renderSessionStatus(cc, sessionID.String(), "latest bot session")
|
||||
},
|
||||
})
|
||||
return g
|
||||
}
|
||||
|
||||
func (h *Handler) renderSessionStatus(cc CommandContext, sessionID string, scope string) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Session info is not available.", nil
|
||||
}
|
||||
pgSessionID, err := parseCommandUUID(sessionID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, pgSessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("count messages: %w", err)
|
||||
}
|
||||
|
||||
var usedTokens int64
|
||||
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, pgSessionID)
|
||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("get usage: %w", err)
|
||||
}
|
||||
if err == nil {
|
||||
usedTokens = latestUsage
|
||||
}
|
||||
|
||||
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, pgSessionID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get cache: %w", err)
|
||||
}
|
||||
|
||||
var cacheHitRate float64
|
||||
if cacheRow.TotalInputTokens > 0 {
|
||||
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
|
||||
}
|
||||
|
||||
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, pgSessionID)
|
||||
|
||||
contextUsage := formatTokens(usedTokens)
|
||||
if contextWindow := h.resolveContextWindow(cc); contextWindow != "" {
|
||||
contextUsage = contextUsage + " / " + contextWindow
|
||||
}
|
||||
|
||||
pairs := []kv{
|
||||
{"Scope", scope},
|
||||
{"Session ID", sessionID},
|
||||
{"Messages", strconv.FormatInt(msgCount, 10)},
|
||||
{"Context", contextUsage},
|
||||
{"Cache Hit Rate", fmt.Sprintf("%.1f%%", cacheHitRate)},
|
||||
{"Cache Read", formatTokens(cacheRow.CacheReadTokens)},
|
||||
{"Cache Write", formatTokens(cacheRow.CacheWriteTokens)},
|
||||
}
|
||||
if len(skills) > 0 {
|
||||
pairs = append(pairs, kv{"Skills", strings.Join(skills, ", ")})
|
||||
}
|
||||
return formatKV(pairs), nil
|
||||
}
|
||||
|
||||
func (h *Handler) resolveContextWindow(cc CommandContext) string {
|
||||
if h.settingsService == nil || h.modelsService == nil {
|
||||
return ""
|
||||
}
|
||||
s, err := h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||
if err != nil || s.ChatModelID == "" {
|
||||
return ""
|
||||
}
|
||||
m, err := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
|
||||
if err != nil || m.Config.ContextWindow == nil {
|
||||
return ""
|
||||
}
|
||||
return formatTokens(int64(*m.Config.ContextWindow))
|
||||
}
|
||||
|
||||
func parseCommandUUID(id string) (pgtype.UUID, error) {
|
||||
parsed, err := uuid.Parse(strings.TrimSpace(id))
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid uuid: %w", err)
|
||||
}
|
||||
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
|
||||
}
|
||||
|
||||
func formatTokens(n int64) string {
|
||||
if n >= 1_000_000 {
|
||||
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
|
||||
|
||||
@@ -18,6 +18,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
|
||||
Name: "summary",
|
||||
Usage: "summary - Token usage summary (last 7 days)",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Usage info is not available.", nil
|
||||
}
|
||||
botUUID, err := parseBotUUID(cc.BotID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -89,6 +92,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
|
||||
Name: "by-model",
|
||||
Usage: "by-model - Token usage grouped by model",
|
||||
Handler: func(cc CommandContext) (string, error) {
|
||||
if h.queries == nil {
|
||||
return "Usage info is not available.", nil
|
||||
}
|
||||
botUUID, err := parseBotUUID(cc.BotID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
Reference in New Issue
Block a user