mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(command): improve slash command UX (#361)
Make slash commands easier to navigate in chat by splitting help into levels, compacting list output, and surfacing current selections for model, search, memory, and browser settings. Also route /status to the active conversation session and add an access inspector so users can understand their current command and ACL context.
This commit is contained in:
@@ -604,6 +604,7 @@ func provideChannelRouter(
|
|||||||
emailOutboxService,
|
emailOutboxService,
|
||||||
heartbeatService,
|
heartbeatService,
|
||||||
queries,
|
queries,
|
||||||
|
aclService,
|
||||||
&commandSkillLoaderAdapter{handler: containerdHandler},
|
&commandSkillLoaderAdapter{handler: containerdHandler},
|
||||||
&commandContainerFSAdapter{manager: manager},
|
&commandContainerFSAdapter{manager: manager},
|
||||||
))
|
))
|
||||||
@@ -779,6 +780,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
|
|||||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
|
||||||
|
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
|
||||||
|
if err != nil {
|
||||||
|
return inbound.SessionResult{}, err
|
||||||
|
}
|
||||||
|
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
||||||
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -492,6 +492,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc
|
|||||||
emailOutboxService,
|
emailOutboxService,
|
||||||
heartbeatService,
|
heartbeatService,
|
||||||
queries,
|
queries,
|
||||||
|
aclService,
|
||||||
&commandSkillLoaderAdapter{handler: containerdHandler},
|
&commandSkillLoaderAdapter{handler: containerdHandler},
|
||||||
&commandContainerFSAdapter{manager: manager},
|
&commandContainerFSAdapter{manager: manager},
|
||||||
))
|
))
|
||||||
@@ -877,6 +878,14 @@ func (a *sessionEnsurerAdapter) EnsureActiveSession(ctx context.Context, botID,
|
|||||||
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *sessionEnsurerAdapter) GetActiveSession(ctx context.Context, routeID string) (inbound.SessionResult, error) {
|
||||||
|
sess, err := a.svc.GetActiveForRoute(ctx, routeID)
|
||||||
|
if err != nil {
|
||||||
|
return inbound.SessionResult{}, err
|
||||||
|
}
|
||||||
|
return inbound.SessionResult{ID: sess.ID, Type: sess.Type}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
func (a *sessionEnsurerAdapter) CreateNewSession(ctx context.Context, botID, routeID, channelType, sessionType string) (inbound.SessionResult, error) {
|
||||||
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
sess, err := a.svc.CreateNewSession(ctx, botID, routeID, channelType, sessionType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ type ttsModelResolver interface {
|
|||||||
// SessionEnsurer resolves or creates an active session for a route.
|
// SessionEnsurer resolves or creates an active session for a route.
|
||||||
type SessionEnsurer interface {
|
type SessionEnsurer interface {
|
||||||
EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error)
|
EnsureActiveSession(ctx context.Context, botID, routeID, channelType string) (SessionResult, error)
|
||||||
|
GetActiveSession(ctx context.Context, routeID string) (SessionResult, error)
|
||||||
// CreateNewSession always creates a fresh session and sets it as the
|
// CreateNewSession always creates a fresh session and sets it as the
|
||||||
// active session for the given route, replacing any previous one.
|
// active session for the given route, replacing any previous one.
|
||||||
// sessionType defaults to "chat" if empty.
|
// sessionType defaults to "chat" if empty.
|
||||||
@@ -298,11 +299,23 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
|||||||
if isStopCommand(cmdText) && isDirectedAtBot(msg) {
|
if isStopCommand(cmdText) && isDirectedAtBot(msg) {
|
||||||
return p.handleStopCommand(ctx, cfg, msg, sender, identity)
|
return p.handleStopCommand(ctx, cfg, msg, sender, identity)
|
||||||
}
|
}
|
||||||
|
if isStatusCommand(cmdText) && isDirectedAtBot(msg) {
|
||||||
|
return p.handleStatusCommand(ctx, cfg, msg, sender, identity)
|
||||||
|
}
|
||||||
|
|
||||||
// Skip generic command handler for mode-prefix commands (/btw, /now, /next)
|
// Skip generic command handler for mode-prefix commands (/btw, /now, /next)
|
||||||
// so they pass through to mode detection below.
|
// so they pass through to mode detection below.
|
||||||
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
|
if p.commandHandler != nil && p.commandHandler.IsCommand(cmdText) && !IsModeCommand(cmdText) && isDirectedAtBot(msg) {
|
||||||
reply, err := p.commandHandler.Execute(ctx, strings.TrimSpace(identity.BotID), strings.TrimSpace(identity.ChannelIdentityID), cmdText)
|
reply, err := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
|
||||||
|
BotID: strings.TrimSpace(identity.BotID),
|
||||||
|
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
|
||||||
|
UserID: strings.TrimSpace(identity.UserID),
|
||||||
|
Text: cmdText,
|
||||||
|
ChannelType: msg.Channel.String(),
|
||||||
|
ConversationType: strings.TrimSpace(msg.Conversation.Type),
|
||||||
|
ConversationID: strings.TrimSpace(msg.Conversation.ID),
|
||||||
|
ThreadID: extractThreadID(msg),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
reply = "Error: " + err.Error()
|
reply = "Error: " + err.Error()
|
||||||
}
|
}
|
||||||
@@ -2489,6 +2502,18 @@ func isNewSessionCommand(cmdText string) bool {
|
|||||||
return parsed.Resource == "new"
|
return parsed.Resource == "new"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStatusCommand(cmdText string) bool {
|
||||||
|
extracted := command.ExtractCommandText(cmdText)
|
||||||
|
if extracted == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
parsed, err := command.Parse(extracted)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return parsed.Resource == "status"
|
||||||
|
}
|
||||||
|
|
||||||
// resolveNewSessionType determines the session type for /new command.
|
// resolveNewSessionType determines the session type for /new command.
|
||||||
// /new chat → chat, /new discuss → discuss, /new (no arg) → default by context.
|
// /new chat → chat, /new discuss → discuss, /new (no arg) → default by context.
|
||||||
// WebUI (local channel) always defaults to chat.
|
// WebUI (local channel) always defaults to chat.
|
||||||
@@ -2611,3 +2636,83 @@ func (p *ChannelInboundProcessor) handleNewSessionCommand(
|
|||||||
Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)},
|
Message: channel.Message{Text: fmt.Sprintf("New %s conversation started.", modeLabel)},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ChannelInboundProcessor) handleStatusCommand(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg channel.ChannelConfig,
|
||||||
|
msg channel.InboundMessage,
|
||||||
|
sender channel.StreamReplySender,
|
||||||
|
identity InboundIdentity,
|
||||||
|
) error {
|
||||||
|
target := strings.TrimSpace(msg.ReplyTarget)
|
||||||
|
if target == "" {
|
||||||
|
return errors.New("reply target missing for /status command")
|
||||||
|
}
|
||||||
|
if p.routeResolver == nil {
|
||||||
|
return sender.Send(ctx, channel.OutboundMessage{
|
||||||
|
Target: target,
|
||||||
|
Message: channel.Message{Text: "Error: route resolver not configured."},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if p.commandHandler == nil {
|
||||||
|
return sender.Send(ctx, channel.OutboundMessage{
|
||||||
|
Target: target,
|
||||||
|
Message: channel.Message{Text: "Error: command handler not configured."},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
threadID := extractThreadID(msg)
|
||||||
|
routeMetadata := buildRouteMetadata(msg, identity)
|
||||||
|
p.enrichConversationAvatar(ctx, cfg, msg, routeMetadata)
|
||||||
|
resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{
|
||||||
|
BotID: identity.BotID,
|
||||||
|
Platform: msg.Channel.String(),
|
||||||
|
ConversationID: msg.Conversation.ID,
|
||||||
|
ThreadID: threadID,
|
||||||
|
ConversationType: msg.Conversation.Type,
|
||||||
|
ChannelIdentityID: identity.UserID,
|
||||||
|
ChannelConfigID: identity.ChannelConfigID,
|
||||||
|
ReplyTarget: target,
|
||||||
|
Metadata: routeMetadata,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if p.logger != nil {
|
||||||
|
p.logger.Warn("resolve route for /status command failed", slog.Any("error", err))
|
||||||
|
}
|
||||||
|
return sender.Send(ctx, channel.OutboundMessage{
|
||||||
|
Target: target,
|
||||||
|
Message: channel.Message{Text: "Error: failed to resolve conversation route."},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := ""
|
||||||
|
if p.sessionEnsurer != nil {
|
||||||
|
sess, sessErr := p.sessionEnsurer.GetActiveSession(ctx, resolved.RouteID)
|
||||||
|
if sessErr == nil {
|
||||||
|
sessionID = strings.TrimSpace(sess.ID)
|
||||||
|
} else if p.logger != nil {
|
||||||
|
p.logger.Debug("resolve active session for /status command failed", slog.Any("error", sessErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, execErr := p.commandHandler.ExecuteWithInput(ctx, command.ExecuteInput{
|
||||||
|
BotID: strings.TrimSpace(identity.BotID),
|
||||||
|
ChannelIdentityID: strings.TrimSpace(identity.ChannelIdentityID),
|
||||||
|
UserID: strings.TrimSpace(identity.UserID),
|
||||||
|
Text: rawTextForCommand(msg, strings.TrimSpace(msg.Message.PlainText())),
|
||||||
|
ChannelType: msg.Channel.String(),
|
||||||
|
ConversationType: strings.TrimSpace(msg.Conversation.Type),
|
||||||
|
ConversationID: strings.TrimSpace(msg.Conversation.ID),
|
||||||
|
ThreadID: threadID,
|
||||||
|
RouteID: strings.TrimSpace(resolved.RouteID),
|
||||||
|
SessionID: sessionID,
|
||||||
|
})
|
||||||
|
if execErr != nil {
|
||||||
|
reply = "Error: " + execErr.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
return sender.Send(ctx, channel.OutboundMessage{
|
||||||
|
Target: target,
|
||||||
|
Message: channel.Message{Text: reply},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,11 +13,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/acl"
|
"github.com/memohai/memoh/internal/acl"
|
||||||
"github.com/memohai/memoh/internal/channel"
|
"github.com/memohai/memoh/internal/channel"
|
||||||
"github.com/memohai/memoh/internal/channel/identities"
|
"github.com/memohai/memoh/internal/channel/identities"
|
||||||
"github.com/memohai/memoh/internal/channel/route"
|
"github.com/memohai/memoh/internal/channel/route"
|
||||||
|
"github.com/memohai/memoh/internal/command"
|
||||||
"github.com/memohai/memoh/internal/conversation"
|
"github.com/memohai/memoh/internal/conversation"
|
||||||
|
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||||
"github.com/memohai/memoh/internal/media"
|
"github.com/memohai/memoh/internal/media"
|
||||||
messagepkg "github.com/memohai/memoh/internal/message"
|
messagepkg "github.com/memohai/memoh/internal/message"
|
||||||
"github.com/memohai/memoh/internal/schedule"
|
"github.com/memohai/memoh/internal/schedule"
|
||||||
@@ -185,6 +189,71 @@ type fakeChatACL struct {
|
|||||||
lastReq acl.EvaluateRequest
|
lastReq acl.EvaluateRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeSessionEnsurer struct {
|
||||||
|
activeSession SessionResult
|
||||||
|
activeErr error
|
||||||
|
lastRouteID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSessionEnsurer) EnsureActiveSession(_ context.Context, _, routeID, _ string) (SessionResult, error) {
|
||||||
|
f.lastRouteID = routeID
|
||||||
|
if f.activeErr != nil {
|
||||||
|
return SessionResult{}, f.activeErr
|
||||||
|
}
|
||||||
|
return f.activeSession, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSessionEnsurer) GetActiveSession(_ context.Context, routeID string) (SessionResult, error) {
|
||||||
|
f.lastRouteID = routeID
|
||||||
|
if f.activeErr != nil {
|
||||||
|
return SessionResult{}, f.activeErr
|
||||||
|
}
|
||||||
|
return f.activeSession, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSessionEnsurer) CreateNewSession(_ context.Context, _, routeID, _, _ string) (SessionResult, error) {
|
||||||
|
f.lastRouteID = routeID
|
||||||
|
if f.activeErr != nil {
|
||||||
|
return SessionResult{}, f.activeErr
|
||||||
|
}
|
||||||
|
return f.activeSession, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCommandQueries struct {
|
||||||
|
messageCount int64
|
||||||
|
usage int64
|
||||||
|
cacheRow dbsqlc.GetSessionCacheStatsRow
|
||||||
|
skills []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
|
||||||
|
return pgtype.UUID{}, errors.New("unexpected latest session lookup")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||||
|
return f.messageCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||||
|
return f.usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
|
||||||
|
return f.cacheRow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
|
||||||
|
return f.skills, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) {
|
func (f *fakeChatACL) Evaluate(_ context.Context, req acl.EvaluateRequest) (bool, error) {
|
||||||
f.calls++
|
f.calls++
|
||||||
f.lastReq = req
|
f.lastReq = req
|
||||||
@@ -548,6 +617,72 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChannelInboundProcessorStatusUsesRouteSession(t *testing.T) {
|
||||||
|
channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-status"}}
|
||||||
|
policySvc := &fakePolicyService{}
|
||||||
|
chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-status", RouteID: "route-status"}}
|
||||||
|
gateway := &fakeChatGateway{}
|
||||||
|
processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, policySvc, nil, "", 0)
|
||||||
|
processor.SetSessionEnsurer(&fakeSessionEnsurer{
|
||||||
|
activeSession: SessionResult{ID: "11111111-1111-1111-1111-111111111111", Type: "chat"},
|
||||||
|
})
|
||||||
|
processor.SetCommandHandler(command.NewHandler(
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
&fakeCommandQueries{
|
||||||
|
messageCount: 9,
|
||||||
|
usage: 512,
|
||||||
|
cacheRow: dbsqlc.GetSessionCacheStatsRow{
|
||||||
|
CacheReadTokens: 64,
|
||||||
|
CacheWriteTokens: 32,
|
||||||
|
TotalInputTokens: 512,
|
||||||
|
},
|
||||||
|
skills: []string{"search"},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
))
|
||||||
|
sender := &fakeReplySender{}
|
||||||
|
|
||||||
|
cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("discord")}
|
||||||
|
msg := channel.InboundMessage{
|
||||||
|
BotID: "bot-1",
|
||||||
|
Channel: channel.ChannelType("discord"),
|
||||||
|
Message: channel.Message{Text: "/status"},
|
||||||
|
ReplyTarget: "discord:status",
|
||||||
|
Sender: channel.Identity{SubjectID: "user-1"},
|
||||||
|
Conversation: channel.Conversation{
|
||||||
|
ID: "conv-status",
|
||||||
|
Type: channel.ConversationTypePrivate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(sender.sent) != 1 {
|
||||||
|
t.Fatalf("expected one status reply, got %d", len(sender.sent))
|
||||||
|
}
|
||||||
|
if !strings.Contains(sender.sent[0].Message.Text, "- Scope: current conversation") {
|
||||||
|
t.Fatalf("expected current conversation scope, got %q", sender.sent[0].Message.Text)
|
||||||
|
}
|
||||||
|
if !strings.Contains(sender.sent[0].Message.Text, "- Session ID: 11111111-1111-1111-1111-111111111111") {
|
||||||
|
t.Fatalf("expected active route session in reply, got %q", sender.sent[0].Message.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) {
|
func TestBuildInboundQueryAttachmentOnlyReturnsEmpty(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
package command
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/memohai/memoh/internal/acl"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) buildAccessGroup() *CommandGroup {
|
||||||
|
g := newCommandGroup("access", "Inspect identity and permission context")
|
||||||
|
g.DefaultAction = "show"
|
||||||
|
g.Register(SubCommand{
|
||||||
|
Name: "show",
|
||||||
|
Usage: "show - Show current identity, write access, and chat ACL context",
|
||||||
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
writeAccess := "no"
|
||||||
|
if cc.Role == "owner" {
|
||||||
|
writeAccess = "yes"
|
||||||
|
}
|
||||||
|
|
||||||
|
pairs := []kv{
|
||||||
|
{"Channel Identity", fallbackValue(cc.ChannelIdentityID)},
|
||||||
|
{"Linked User", fallbackValue(cc.UserID)},
|
||||||
|
{"Bot Role", fallbackValue(cc.Role)},
|
||||||
|
{"Write Commands", writeAccess},
|
||||||
|
{"Channel", fallbackValue(cc.ChannelType)},
|
||||||
|
{"Conversation Type", fallbackValue(cc.ConversationType)},
|
||||||
|
{"Conversation ID", fallbackValue(cc.ConversationID)},
|
||||||
|
{"Thread ID", fallbackValue(cc.ThreadID)},
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cc.RouteID) != "" {
|
||||||
|
pairs = append(pairs, kv{"Route ID", cc.RouteID})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cc.SessionID) != "" {
|
||||||
|
pairs = append(pairs, kv{"Session ID", cc.SessionID})
|
||||||
|
}
|
||||||
|
|
||||||
|
aclStatus := "unavailable"
|
||||||
|
if h.aclEvaluator != nil && strings.TrimSpace(cc.ChannelType) != "" {
|
||||||
|
allowed, err := h.aclEvaluator.Evaluate(cc.Ctx, acl.EvaluateRequest{
|
||||||
|
BotID: cc.BotID,
|
||||||
|
ChannelIdentityID: cc.ChannelIdentityID,
|
||||||
|
ChannelType: cc.ChannelType,
|
||||||
|
SourceScope: acl.SourceScope{
|
||||||
|
ConversationType: cc.ConversationType,
|
||||||
|
ConversationID: cc.ConversationID,
|
||||||
|
ThreadID: cc.ThreadID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
aclStatus = "error: " + err.Error()
|
||||||
|
case allowed:
|
||||||
|
aclStatus = "allow"
|
||||||
|
default:
|
||||||
|
aclStatus = "deny"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pairs = append(pairs, kv{"Chat ACL", aclStatus})
|
||||||
|
|
||||||
|
return formatKV(pairs), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func fallbackValue(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return "(none)"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatChangedValue(label, before, after string) string {
|
||||||
|
return fmt.Sprintf("%s: %s -> %s", label, fallbackValue(before), fallbackValue(after))
|
||||||
|
}
|
||||||
@@ -13,6 +13,9 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
|||||||
Name: "list",
|
Name: "list",
|
||||||
Usage: "list - List all browser contexts",
|
Usage: "list - List all browser contexts",
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.browserCtxService == nil {
|
||||||
|
return "Browser context service is not available.", nil
|
||||||
|
}
|
||||||
items, err := h.browserCtxService.List(cc.Ctx)
|
items, err := h.browserCtxService.List(cc.Ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -20,13 +23,37 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
|||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
return "No browser contexts found.", nil
|
return "No browser contexts found.", nil
|
||||||
}
|
}
|
||||||
records := make([][]kv, 0, len(items))
|
settingsResp, _ := h.getBotSettings(cc)
|
||||||
|
currentRecords := make([][]kv, 0, 1)
|
||||||
|
otherRecords := make([][]kv, 0, len(items))
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
records = append(records, []kv{
|
label := item.Name
|
||||||
{"Name", item.Name},
|
record := []kv{{"Name", label}}
|
||||||
})
|
if item.ID == settingsResp.BrowserContextID {
|
||||||
|
label += " [current]"
|
||||||
|
record[0].value = label
|
||||||
|
currentRecords = append(currentRecords, record)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
otherRecords = append(otherRecords, record)
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
currentRecords = append(currentRecords, otherRecords...)
|
||||||
|
records := currentRecords
|
||||||
|
return formatLimitedItems(records, defaultListLimit, "Use /browser current to inspect the active context."), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
g.Register(SubCommand{
|
||||||
|
Name: "current",
|
||||||
|
Usage: "current - Show the current browser context",
|
||||||
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
|
settingsResp, err := h.getBotSettings(cc)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return formatKV([]kv{{"Browser Context", h.resolveBrowserContextName(cc, settingsResp.BrowserContextID)}}), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
@@ -37,7 +64,11 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
|||||||
if len(cc.Args) < 1 {
|
if len(cc.Args) < 1 {
|
||||||
return "Usage: /browser set <name>", nil
|
return "Usage: /browser set <name>", nil
|
||||||
}
|
}
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
name := cc.Args[0]
|
name := cc.Args[0]
|
||||||
|
before, _ := h.getBotSettings(cc)
|
||||||
items, err := h.browserCtxService.List(cc.Ctx)
|
items, err := h.browserCtxService.List(cc.Ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -50,7 +81,7 @@ func (h *Handler) buildBrowserGroup() *CommandGroup {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Browser context set to %q.", item.Name), nil
|
return formatChangedValue("Browser context", h.resolveBrowserContextName(cc, before.BrowserContextID), item.Name), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Browser context %q not found.", name), nil
|
return fmt.Sprintf("Browser context %q not found.", name), nil
|
||||||
|
|||||||
@@ -16,5 +16,6 @@ func (h *Handler) buildRegistry() *Registry {
|
|||||||
r.RegisterGroup(h.buildSkillGroup())
|
r.RegisterGroup(h.buildSkillGroup())
|
||||||
r.RegisterGroup(h.buildFSGroup())
|
r.RegisterGroup(h.buildFSGroup())
|
||||||
r.RegisterGroup(h.buildStatusGroup())
|
r.RegisterGroup(h.buildStatusGroup())
|
||||||
|
r.RegisterGroup(h.buildAccessGroup())
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
|||||||
{"Provider", item.Provider},
|
{"Provider", item.Provider},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
return formatLimitedItems(records, defaultListLimit, "Use /email bindings to inspect bot bindings."), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
@@ -46,7 +46,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
|||||||
{"Permissions", perms},
|
{"Permissions", perms},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
return formatLimitedItems(records, defaultListLimit, "Use /email outbox to inspect recent sends."), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
@@ -70,7 +70,7 @@ func (h *Handler) buildEmailGroup() *CommandGroup {
|
|||||||
{"Sent", item.SentAt.Format("01-02 15:04")},
|
{"Sent", item.SentAt.Format("01-02 15:04")},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
return formatLimitedItems(records, 10, "Use the Web UI for older outbox entries."), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return g
|
return g
|
||||||
|
|||||||
@@ -6,19 +6,10 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultListLimit = 12
|
||||||
|
|
||||||
// formatItems renders a list of records as a Markdown-style list.
|
// formatItems renders a list of records as a Markdown-style list.
|
||||||
// Each record is a slice of kv pairs; the first pair's value is used as the
|
// Each record is rendered on a single line so long lists stay readable in IM.
|
||||||
// bullet title, and subsequent pairs are indented beneath it.
|
|
||||||
//
|
|
||||||
// Example output:
|
|
||||||
//
|
|
||||||
// - mybot
|
|
||||||
// Description: A helpful assistant
|
|
||||||
// ID: abc123
|
|
||||||
//
|
|
||||||
// - another
|
|
||||||
// Description: Something else
|
|
||||||
// ID: def456
|
|
||||||
func formatItems(items [][]kv) string {
|
func formatItems(items [][]kv) string {
|
||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
return ""
|
return ""
|
||||||
@@ -31,14 +22,41 @@ func formatItems(items [][]kv) string {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
b.WriteByte('\n')
|
b.WriteByte('\n')
|
||||||
}
|
}
|
||||||
fmt.Fprintf(&b, "- %s\n", record[0].value)
|
fmt.Fprintf(&b, "- %s", record[0].value)
|
||||||
|
extras := make([]string, 0, len(record)-1)
|
||||||
for _, pair := range record[1:] {
|
for _, pair := range record[1:] {
|
||||||
fmt.Fprintf(&b, " %s: %s\n", pair.key, pair.value)
|
if strings.TrimSpace(pair.value) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
extras = append(extras, fmt.Sprintf("%s: %s", pair.key, pair.value))
|
||||||
|
}
|
||||||
|
if len(extras) > 0 {
|
||||||
|
fmt.Fprintf(&b, " | %s", strings.Join(extras, " | "))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatLimitedItems(items [][]kv, limit int, hint string) string {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = defaultListLimit
|
||||||
|
}
|
||||||
|
total := len(items)
|
||||||
|
if total <= limit {
|
||||||
|
return formatItems(items)
|
||||||
|
}
|
||||||
|
shown := items[:limit]
|
||||||
|
result := formatItems(shown)
|
||||||
|
suffix := fmt.Sprintf("Showing %d of %d items.", len(shown), total)
|
||||||
|
if strings.TrimSpace(hint) != "" {
|
||||||
|
suffix += " " + strings.TrimSpace(hint)
|
||||||
|
}
|
||||||
|
return result + "\n\n" + suffix
|
||||||
|
}
|
||||||
|
|
||||||
// formatKV renders key-value pairs as a simple Markdown list.
|
// formatKV renders key-value pairs as a simple Markdown list.
|
||||||
//
|
//
|
||||||
// Example output:
|
// Example output:
|
||||||
|
|||||||
+57
-13
@@ -4,10 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/bots"
|
"github.com/memohai/memoh/internal/bots"
|
||||||
"github.com/memohai/memoh/internal/browsercontexts"
|
"github.com/memohai/memoh/internal/browsercontexts"
|
||||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
|
||||||
emailpkg "github.com/memohai/memoh/internal/email"
|
emailpkg "github.com/memohai/memoh/internal/email"
|
||||||
"github.com/memohai/memoh/internal/heartbeat"
|
"github.com/memohai/memoh/internal/heartbeat"
|
||||||
"github.com/memohai/memoh/internal/mcp"
|
"github.com/memohai/memoh/internal/mcp"
|
||||||
@@ -56,13 +56,28 @@ type Handler struct {
|
|||||||
emailService *emailpkg.Service
|
emailService *emailpkg.Service
|
||||||
emailOutboxService *emailpkg.OutboxService
|
emailOutboxService *emailpkg.OutboxService
|
||||||
heartbeatService *heartbeat.Service
|
heartbeatService *heartbeat.Service
|
||||||
queries *dbsqlc.Queries
|
queries CommandQueries
|
||||||
|
aclEvaluator AccessEvaluator
|
||||||
skillLoader SkillLoader
|
skillLoader SkillLoader
|
||||||
containerFS ContainerFS
|
containerFS ContainerFS
|
||||||
|
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteInput carries the caller identity and channel context for command execution.
|
||||||
|
type ExecuteInput struct {
|
||||||
|
BotID string
|
||||||
|
ChannelIdentityID string
|
||||||
|
UserID string
|
||||||
|
Text string
|
||||||
|
ChannelType string
|
||||||
|
ConversationType string
|
||||||
|
ConversationID string
|
||||||
|
ThreadID string
|
||||||
|
RouteID string
|
||||||
|
SessionID string
|
||||||
|
}
|
||||||
|
|
||||||
// NewHandler creates a Handler with all required services.
|
// NewHandler creates a Handler with all required services.
|
||||||
func NewHandler(
|
func NewHandler(
|
||||||
log *slog.Logger,
|
log *slog.Logger,
|
||||||
@@ -78,7 +93,8 @@ func NewHandler(
|
|||||||
emailService *emailpkg.Service,
|
emailService *emailpkg.Service,
|
||||||
emailOutboxService *emailpkg.OutboxService,
|
emailOutboxService *emailpkg.OutboxService,
|
||||||
heartbeatService *heartbeat.Service,
|
heartbeatService *heartbeat.Service,
|
||||||
queries *dbsqlc.Queries,
|
queries CommandQueries,
|
||||||
|
aclEvaluator AccessEvaluator,
|
||||||
skillLoader SkillLoader,
|
skillLoader SkillLoader,
|
||||||
containerFS ContainerFS,
|
containerFS ContainerFS,
|
||||||
) *Handler {
|
) *Handler {
|
||||||
@@ -99,6 +115,7 @@ func NewHandler(
|
|||||||
emailOutboxService: emailOutboxService,
|
emailOutboxService: emailOutboxService,
|
||||||
heartbeatService: heartbeatService,
|
heartbeatService: heartbeatService,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
|
aclEvaluator: aclEvaluator,
|
||||||
skillLoader: skillLoader,
|
skillLoader: skillLoader,
|
||||||
containerFS: containerFS,
|
containerFS: containerFS,
|
||||||
logger: log.With(slog.String("component", "command")),
|
logger: log.With(slog.String("component", "command")),
|
||||||
@@ -140,7 +157,16 @@ func (h *Handler) IsCommand(text string) bool {
|
|||||||
|
|
||||||
// Execute parses and runs a slash command, returning the text reply.
|
// Execute parses and runs a slash command, returning the text reply.
|
||||||
func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) {
|
func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text string) (string, error) {
|
||||||
cmdText := ExtractCommandText(text)
|
return h.ExecuteWithInput(ctx, ExecuteInput{
|
||||||
|
BotID: botID,
|
||||||
|
ChannelIdentityID: channelIdentityID,
|
||||||
|
Text: text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteWithInput parses and runs a slash command with channel/session context.
|
||||||
|
func (h *Handler) ExecuteWithInput(ctx context.Context, input ExecuteInput) (string, error) {
|
||||||
|
cmdText := ExtractCommandText(input.Text)
|
||||||
if cmdText == "" {
|
if cmdText == "" {
|
||||||
return h.registry.GlobalHelp(), nil
|
return h.registry.GlobalHelp(), nil
|
||||||
}
|
}
|
||||||
@@ -151,12 +177,16 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
|
|||||||
|
|
||||||
// Resolve the user's role in this bot.
|
// Resolve the user's role in this bot.
|
||||||
role := ""
|
role := ""
|
||||||
if h.roleResolver != nil && channelIdentityID != "" {
|
roleIdentityID := input.ChannelIdentityID
|
||||||
r, err := h.roleResolver.GetMemberRole(ctx, botID, channelIdentityID)
|
if strings.TrimSpace(input.UserID) != "" {
|
||||||
|
roleIdentityID = strings.TrimSpace(input.UserID)
|
||||||
|
}
|
||||||
|
if h.roleResolver != nil && roleIdentityID != "" {
|
||||||
|
r, err := h.roleResolver.GetMemberRole(ctx, input.BotID, roleIdentityID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("failed to resolve member role",
|
h.logger.Warn("failed to resolve member role",
|
||||||
slog.String("bot_id", botID),
|
slog.String("bot_id", input.BotID),
|
||||||
slog.String("channel_identity_id", channelIdentityID),
|
slog.String("role_identity_id", roleIdentityID),
|
||||||
slog.Any("error", err),
|
slog.Any("error", err),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@@ -165,15 +195,29 @@ func (h *Handler) Execute(ctx context.Context, botID, channelIdentityID, text st
|
|||||||
}
|
}
|
||||||
|
|
||||||
cc := CommandContext{
|
cc := CommandContext{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
BotID: botID,
|
BotID: input.BotID,
|
||||||
Role: role,
|
Role: role,
|
||||||
Args: parsed.Args,
|
Args: parsed.Args,
|
||||||
|
ChannelIdentityID: strings.TrimSpace(input.ChannelIdentityID),
|
||||||
|
UserID: strings.TrimSpace(input.UserID),
|
||||||
|
ChannelType: strings.TrimSpace(input.ChannelType),
|
||||||
|
ConversationType: strings.TrimSpace(input.ConversationType),
|
||||||
|
ConversationID: strings.TrimSpace(input.ConversationID),
|
||||||
|
ThreadID: strings.TrimSpace(input.ThreadID),
|
||||||
|
RouteID: strings.TrimSpace(input.RouteID),
|
||||||
|
SessionID: strings.TrimSpace(input.SessionID),
|
||||||
}
|
}
|
||||||
|
|
||||||
// /help
|
// /help
|
||||||
if parsed.Resource == "help" {
|
if parsed.Resource == "help" {
|
||||||
return h.registry.GlobalHelp(), nil
|
if parsed.Action == "" {
|
||||||
|
return h.registry.GlobalHelp(), nil
|
||||||
|
}
|
||||||
|
if len(parsed.Args) == 0 {
|
||||||
|
return h.registry.GroupHelp(parsed.Action), nil
|
||||||
|
}
|
||||||
|
return h.registry.ActionHelp(parsed.Action, parsed.Args[0]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Top-level commands (e.g. /new) are handled by the channel inbound
|
// Top-level commands (e.g. /new) are handled by the channel inbound
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
|
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||||
"github.com/memohai/memoh/internal/mcp"
|
"github.com/memohai/memoh/internal/mcp"
|
||||||
"github.com/memohai/memoh/internal/schedule"
|
"github.com/memohai/memoh/internal/schedule"
|
||||||
"github.com/memohai/memoh/internal/settings"
|
"github.com/memohai/memoh/internal/settings"
|
||||||
@@ -25,9 +29,58 @@ type fakeScheduleService struct {
|
|||||||
items []schedule.Schedule
|
items []schedule.Schedule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeCommandQueries struct {
|
||||||
|
latestSessionID pgtype.UUID
|
||||||
|
latestSessionErr error
|
||||||
|
messageCount int64
|
||||||
|
latestUsage int64
|
||||||
|
latestUsageErr error
|
||||||
|
cacheRow dbsqlc.GetSessionCacheStatsRow
|
||||||
|
cacheErr error
|
||||||
|
skills []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) GetLatestSessionIDByBot(_ context.Context, _ pgtype.UUID) (pgtype.UUID, error) {
|
||||||
|
return f.latestSessionID, f.latestSessionErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) CountMessagesBySession(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||||
|
return f.messageCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) GetLatestAssistantUsage(_ context.Context, _ pgtype.UUID) (int64, error) {
|
||||||
|
if f.latestUsageErr != nil {
|
||||||
|
return 0, f.latestUsageErr
|
||||||
|
}
|
||||||
|
return f.latestUsage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) GetSessionCacheStats(_ context.Context, _ pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error) {
|
||||||
|
if f.cacheErr != nil {
|
||||||
|
return dbsqlc.GetSessionCacheStatsRow{}, f.cacheErr
|
||||||
|
}
|
||||||
|
return f.cacheRow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCommandQueries) GetSessionUsedSkills(_ context.Context, _ pgtype.UUID) ([]string, error) {
|
||||||
|
return f.skills, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeCommandQueries) GetTokenUsageByDayAndType(_ context.Context, _ dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeCommandQueries) GetTokenUsageByModel(_ context.Context, _ dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// newTestHandler creates a Handler with nil services for use in tests.
|
// newTestHandler creates a Handler with nil services for use in tests.
|
||||||
func newTestHandler(roleResolver MemberRoleResolver) *Handler {
|
func newTestHandler(roleResolver MemberRoleResolver) *Handler {
|
||||||
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestHandlerWithQueries(roleResolver MemberRoleResolver, queries CommandQueries) *Handler {
|
||||||
|
return NewHandler(nil, roleResolver, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, queries, nil, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- tests ---
|
// --- tests ---
|
||||||
@@ -73,6 +126,42 @@ func TestExecute_Help(t *testing.T) {
|
|||||||
if !strings.Contains(result, "Available commands") {
|
if !strings.Contains(result, "Available commands") {
|
||||||
t.Errorf("expected help text, got: %s", result)
|
t.Errorf("expected help text, got: %s", result)
|
||||||
}
|
}
|
||||||
|
if strings.Contains(result, "set-heartbeat") {
|
||||||
|
t.Errorf("top-level help should not expand nested actions, got: %s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "- /model - Manage bot models") {
|
||||||
|
t.Errorf("expected top-level model entry, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecute_HelpGroup(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||||
|
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "/model - Manage bot models") {
|
||||||
|
t.Errorf("expected group help, got: %s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "- set - Set the chat model [owner]") {
|
||||||
|
t.Errorf("expected compact action summary, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecute_HelpAction(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||||
|
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/help model set")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "Usage: /model set <model_id> | <provider_name> <model_name>") {
|
||||||
|
t.Errorf("expected action usage, got: %s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "Access: owner only") {
|
||||||
|
t.Errorf("expected owner hint, got: %s", result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecute_UnknownCommand(t *testing.T) {
|
func TestExecute_UnknownCommand(t *testing.T) {
|
||||||
@@ -207,8 +296,8 @@ func TestFormatItems(t *testing.T) {
|
|||||||
if !strings.Contains(result, "- foo") {
|
if !strings.Contains(result, "- foo") {
|
||||||
t.Errorf("expected '- foo' bullet, got: %s", result)
|
t.Errorf("expected '- foo' bullet, got: %s", result)
|
||||||
}
|
}
|
||||||
if !strings.Contains(result, " Type: bar") {
|
if !strings.Contains(result, "- foo | Type: bar") {
|
||||||
t.Errorf("expected indented 'Type: bar', got: %s", result)
|
t.Errorf("expected compact line entry, got: %s", result)
|
||||||
}
|
}
|
||||||
if !strings.Contains(result, "- longname") {
|
if !strings.Contains(result, "- longname") {
|
||||||
t.Errorf("expected '- longname' bullet, got: %s", result)
|
t.Errorf("expected '- longname' bullet, got: %s", result)
|
||||||
@@ -255,7 +344,7 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
|
|||||||
for _, group := range []string{
|
for _, group := range []string{
|
||||||
"schedule", "mcp", "settings",
|
"schedule", "mcp", "settings",
|
||||||
"model", "memory", "search", "browser", "usage",
|
"model", "memory", "search", "browser", "usage",
|
||||||
"email", "heartbeat", "skill", "fs",
|
"email", "heartbeat", "skill", "fs", "access",
|
||||||
} {
|
} {
|
||||||
if !strings.Contains(help, "/"+group) {
|
if !strings.Contains(help, "/"+group) {
|
||||||
t.Errorf("missing /%s in global help", group)
|
t.Errorf("missing /%s in global help", group)
|
||||||
@@ -263,6 +352,86 @@ func TestGlobalHelp_AllGroups(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithInput_Access(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := newTestHandler(&fakeRoleResolver{role: "owner"})
|
||||||
|
result, err := h.ExecuteWithInput(context.Background(), ExecuteInput{
|
||||||
|
BotID: "bot-1",
|
||||||
|
ChannelIdentityID: "channel-id-1",
|
||||||
|
UserID: "user-id-1",
|
||||||
|
Text: "/access",
|
||||||
|
ChannelType: "discord",
|
||||||
|
ConversationType: "thread",
|
||||||
|
ConversationID: "conv-1",
|
||||||
|
ThreadID: "thread-1",
|
||||||
|
RouteID: "route-1",
|
||||||
|
SessionID: "session-1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "- Channel Identity: channel-id-1") {
|
||||||
|
t.Errorf("expected channel identity in access output, got: %s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "- Write Commands: yes") {
|
||||||
|
t.Errorf("expected write access in access output, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecute_StatusLatest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
sessionUUID := pgtype.UUID{}
|
||||||
|
copy(sessionUUID.Bytes[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||||
|
sessionUUID.Valid = true
|
||||||
|
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
|
||||||
|
latestSessionID: sessionUUID,
|
||||||
|
messageCount: 42,
|
||||||
|
latestUsage: 1200,
|
||||||
|
cacheRow: dbsqlc.GetSessionCacheStatsRow{
|
||||||
|
CacheReadTokens: 300,
|
||||||
|
CacheWriteTokens: 150,
|
||||||
|
TotalInputTokens: 1200,
|
||||||
|
},
|
||||||
|
skills: []string{"search", "browser"},
|
||||||
|
})
|
||||||
|
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "- Scope: latest bot session") {
|
||||||
|
t.Errorf("expected latest scope, got: %s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "- Messages: 42") {
|
||||||
|
t.Errorf("expected message count, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecute_StatusLatestNoRows(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{
|
||||||
|
latestSessionErr: pgx.ErrNoRows,
|
||||||
|
})
|
||||||
|
result, err := h.Execute(context.Background(), "11111111-1111-1111-1111-111111111111", "user-1", "/status latest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "No session found for this bot.") {
|
||||||
|
t.Errorf("expected no session message, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecute_StatusShowWithoutSession(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := newTestHandlerWithQueries(&fakeRoleResolver{role: "owner"}, &fakeCommandQueries{})
|
||||||
|
result, err := h.Execute(context.Background(), "bot-1", "user-1", "/status")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "No active session found for this conversation.") {
|
||||||
|
t.Errorf("expected route-aware no session message, got: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Verify write commands are tagged with [owner] in usage.
|
// Verify write commands are tagged with [owner] in usage.
|
||||||
func TestUsage_OwnerTag(t *testing.T) {
|
func TestUsage_OwnerTag(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (h *Handler) buildHeartbeatGroup() *CommandGroup {
|
|||||||
}
|
}
|
||||||
records = append(records, rec)
|
records = append(records, rec)
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
return formatLimitedItems(records, 10, "Use the Web UI for older heartbeat logs."), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return g
|
return g
|
||||||
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
package command
|
package command
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
|
"github.com/memohai/memoh/internal/acl"
|
||||||
|
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||||
|
)
|
||||||
|
|
||||||
// Skill represents a single skill loaded from a bot's container.
|
// Skill represents a single skill loaded from a bot's container.
|
||||||
type Skill struct {
|
type Skill struct {
|
||||||
@@ -25,3 +32,20 @@ type ContainerFS interface {
|
|||||||
ListDir(ctx context.Context, botID, path string) ([]FSEntry, error)
|
ListDir(ctx context.Context, botID, path string) ([]FSEntry, error)
|
||||||
ReadFile(ctx context.Context, botID, path string) (string, error)
|
ReadFile(ctx context.Context, botID, path string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CommandQueries captures the sqlc methods used by slash commands.
|
||||||
|
// *dbsqlc.Queries satisfies this interface directly.
|
||||||
|
type CommandQueries interface {
|
||||||
|
GetLatestSessionIDByBot(ctx context.Context, botID pgtype.UUID) (pgtype.UUID, error)
|
||||||
|
CountMessagesBySession(ctx context.Context, sessionID pgtype.UUID) (int64, error)
|
||||||
|
GetLatestAssistantUsage(ctx context.Context, sessionID pgtype.UUID) (int64, error)
|
||||||
|
GetSessionCacheStats(ctx context.Context, sessionID pgtype.UUID) (dbsqlc.GetSessionCacheStatsRow, error)
|
||||||
|
GetSessionUsedSkills(ctx context.Context, sessionID pgtype.UUID) ([]string, error)
|
||||||
|
GetTokenUsageByDayAndType(ctx context.Context, arg dbsqlc.GetTokenUsageByDayAndTypeParams) ([]dbsqlc.GetTokenUsageByDayAndTypeRow, error)
|
||||||
|
GetTokenUsageByModel(ctx context.Context, arg dbsqlc.GetTokenUsageByModelParams) ([]dbsqlc.GetTokenUsageByModelRow, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessEvaluator checks whether the current channel context may trigger chat.
|
||||||
|
type AccessEvaluator interface {
|
||||||
|
Evaluate(ctx context.Context, req acl.EvaluateRequest) (bool, error)
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func (h *Handler) buildMCPGroup() *CommandGroup {
|
|||||||
{"Status", item.Status},
|
{"Status", item.Status},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
return formatLimitedItems(records, defaultListLimit, "Use /mcp get <name> for full details."), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
|||||||
Name: "list",
|
Name: "list",
|
||||||
Usage: "list - List all memory providers",
|
Usage: "list - List all memory providers",
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.memProvService == nil {
|
||||||
|
return "Memory provider service is not available.", nil
|
||||||
|
}
|
||||||
items, err := h.memProvService.List(cc.Ctx)
|
items, err := h.memProvService.List(cc.Ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -20,18 +23,44 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
|||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
return "No memory providers found.", nil
|
return "No memory providers found.", nil
|
||||||
}
|
}
|
||||||
records := make([][]kv, 0, len(items))
|
settingsResp, _ := h.getBotSettings(cc)
|
||||||
|
currentRecords := make([][]kv, 0, 1)
|
||||||
|
otherRecords := make([][]kv, 0, len(items))
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
def := ""
|
def := ""
|
||||||
if item.IsDefault {
|
if item.IsDefault {
|
||||||
def = " (default)"
|
def = " (default)"
|
||||||
}
|
}
|
||||||
records = append(records, []kv{
|
label := item.Name + def
|
||||||
{"Name", item.Name + def},
|
record := []kv{
|
||||||
|
{"Name", label},
|
||||||
{"Provider", item.Provider},
|
{"Provider", item.Provider},
|
||||||
})
|
}
|
||||||
|
if item.ID == settingsResp.MemoryProviderID {
|
||||||
|
label += " [current]"
|
||||||
|
record[0].value = label
|
||||||
|
currentRecords = append(currentRecords, record)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
otherRecords = append(otherRecords, record)
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
currentRecords = append(currentRecords, otherRecords...)
|
||||||
|
records := currentRecords
|
||||||
|
return formatLimitedItems(records, defaultListLimit, "Use /memory current to inspect the active provider."), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
g.Register(SubCommand{
|
||||||
|
Name: "current",
|
||||||
|
Usage: "current - Show the current memory provider",
|
||||||
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
|
settingsResp, err := h.getBotSettings(cc)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return formatKV([]kv{{"Memory Provider", h.resolveMemoryProviderName(cc, settingsResp.MemoryProviderID)}}), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
@@ -42,7 +71,11 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
|||||||
if len(cc.Args) < 1 {
|
if len(cc.Args) < 1 {
|
||||||
return "Usage: /memory set <name>", nil
|
return "Usage: /memory set <name>", nil
|
||||||
}
|
}
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
name := cc.Args[0]
|
name := cc.Args[0]
|
||||||
|
before, _ := h.getBotSettings(cc)
|
||||||
items, err := h.memProvService.List(cc.Ctx)
|
items, err := h.memProvService.List(cc.Ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -55,7 +88,7 @@ func (h *Handler) buildMemoryGroup() *CommandGroup {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Memory provider set to %q.", item.Name), nil
|
return formatChangedValue("Memory provider", h.resolveMemoryProviderName(cc, before.MemoryProviderID), item.Name), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Memory provider %q not found.", name), nil
|
return fmt.Sprintf("Memory provider %q not found.", name), nil
|
||||||
|
|||||||
+147
-13
@@ -1,7 +1,9 @@
|
|||||||
package command
|
package command
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/memohai/memoh/internal/models"
|
"github.com/memohai/memoh/internal/models"
|
||||||
@@ -12,36 +14,81 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
|||||||
g := newCommandGroup("model", "Manage bot models")
|
g := newCommandGroup("model", "Manage bot models")
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
Name: "list",
|
Name: "list",
|
||||||
Usage: "list - List all available chat models",
|
Usage: "list [provider_name] - List available chat models",
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.modelsService == nil {
|
||||||
|
return "Model service is not available.", nil
|
||||||
|
}
|
||||||
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
|
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
filterProvider := ""
|
||||||
|
if len(cc.Args) > 0 {
|
||||||
|
filterProvider = strings.TrimSpace(strings.Join(cc.Args, " "))
|
||||||
|
}
|
||||||
|
items = h.filterModelsByProvider(cc, items, filterProvider)
|
||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
|
if filterProvider != "" {
|
||||||
|
return fmt.Sprintf("No chat models found for provider %q.", filterProvider), nil
|
||||||
|
}
|
||||||
return "No chat models found.", nil
|
return "No chat models found.", nil
|
||||||
}
|
}
|
||||||
|
settingsResp, _ := h.getBotSettings(cc)
|
||||||
|
sort.SliceStable(items, func(i, j int) bool {
|
||||||
|
return modelSortRank(items[i], settingsResp) < modelSortRank(items[j], settingsResp)
|
||||||
|
})
|
||||||
records := make([][]kv, 0, len(items))
|
records := make([][]kv, 0, len(items))
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
provName := h.resolveProviderName(cc, item.ProviderID)
|
provName := h.resolveProviderName(cc, item.ProviderID)
|
||||||
|
label := item.Name
|
||||||
|
markers := modelMarkers(item.ID, settingsResp)
|
||||||
|
if len(markers) > 0 {
|
||||||
|
label += " [" + strings.Join(markers, ", ") + "]"
|
||||||
|
}
|
||||||
records = append(records, []kv{
|
records = append(records, []kv{
|
||||||
{"Model", item.Name},
|
{"Model", label},
|
||||||
{"Provider", provName},
|
{"Provider", provName},
|
||||||
{"Model ID", item.ModelID},
|
{"Model ID", item.ModelID},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
hint := "Use /model current to inspect active selections."
|
||||||
|
if filterProvider == "" {
|
||||||
|
hint = "Use /model list <provider_name> to narrow results."
|
||||||
|
}
|
||||||
|
return formatLimitedItems(records, defaultListLimit, hint), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
g.Register(SubCommand{
|
||||||
|
Name: "current",
|
||||||
|
Usage: "current - Show current chat and heartbeat models",
|
||||||
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
|
settingsResp, err := h.getBotSettings(cc)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return formatKV([]kv{
|
||||||
|
{"Chat Model", h.resolveModelName(cc, settingsResp.ChatModelID)},
|
||||||
|
{"Heartbeat Model", h.resolveModelName(cc, settingsResp.HeartbeatModelID)},
|
||||||
|
}), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
Name: "set",
|
Name: "set",
|
||||||
Usage: "set <provider_name> <model_name> - Set the chat model",
|
Usage: "set <model_id> | <provider_name> <model_name> - Set the chat model",
|
||||||
IsWrite: true,
|
IsWrite: true,
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
if len(cc.Args) < 2 {
|
if len(cc.Args) < 1 {
|
||||||
return "Usage: /model set <provider_name> <model_name>", nil
|
return "Usage: /model set <model_id> | <provider_name> <model_name>", nil
|
||||||
}
|
}
|
||||||
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
|
before, _ := h.getBotSettings(cc)
|
||||||
|
modelResp, err := h.findModelForSelection(cc, cc.Args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -51,18 +98,22 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Chat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
|
return formatChangedValue("Chat model", h.resolveModelName(cc, before.ChatModelID), h.resolveModelName(cc, modelResp.ID)), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
Name: "set-heartbeat",
|
Name: "set-heartbeat",
|
||||||
Usage: "set-heartbeat <provider_name> <model_name> - Set the heartbeat model",
|
Usage: "set-heartbeat <model_id> | <provider_name> <model_name> - Set the heartbeat model",
|
||||||
IsWrite: true,
|
IsWrite: true,
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
if len(cc.Args) < 2 {
|
if len(cc.Args) < 1 {
|
||||||
return "Usage: /model set-heartbeat <provider_name> <model_name>", nil
|
return "Usage: /model set-heartbeat <model_id> | <provider_name> <model_name>", nil
|
||||||
}
|
}
|
||||||
modelResp, err := h.findModelByProviderAndName(cc, cc.Args[0], cc.Args[1])
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
|
before, _ := h.getBotSettings(cc)
|
||||||
|
modelResp, err := h.findModelForSelection(cc, cc.Args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -72,7 +123,7 @@ func (h *Handler) buildModelGroup() *CommandGroup {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Heartbeat model set to %s (%s).", modelResp.Name, cc.Args[0]), nil
|
return formatChangedValue("Heartbeat model", h.resolveModelName(cc, before.HeartbeatModelID), h.resolveModelName(cc, modelResp.ID)), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return g
|
return g
|
||||||
@@ -105,3 +156,86 @@ func (h *Handler) findModelByProviderAndName(cc CommandContext, providerName, mo
|
|||||||
}
|
}
|
||||||
return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName)
|
return models.GetResponse{}, fmt.Errorf("model %q not found under provider %q", modelName, providerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) findModelForSelection(cc CommandContext, args []string) (models.GetResponse, error) {
|
||||||
|
if h.modelsService == nil {
|
||||||
|
return models.GetResponse{}, errors.New("model service is not available")
|
||||||
|
}
|
||||||
|
if len(args) == 0 {
|
||||||
|
return models.GetResponse{}, errors.New("model identifier is required")
|
||||||
|
}
|
||||||
|
if len(args) == 1 {
|
||||||
|
return h.findModelByIDOrName(cc, args[0])
|
||||||
|
}
|
||||||
|
return h.findModelByProviderAndName(cc, args[0], strings.Join(args[1:], " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) findModelByIDOrName(cc CommandContext, target string) (models.GetResponse, error) {
|
||||||
|
items, err := h.modelsService.ListByType(cc.Ctx, models.ModelTypeChat)
|
||||||
|
if err != nil {
|
||||||
|
return models.GetResponse{}, err
|
||||||
|
}
|
||||||
|
target = strings.TrimSpace(target)
|
||||||
|
if target == "" {
|
||||||
|
return models.GetResponse{}, errors.New("model identifier is required")
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.EqualFold(item.ModelID, target) {
|
||||||
|
return item, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
matches := make([]models.GetResponse, 0, 4)
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.EqualFold(item.Name, target) {
|
||||||
|
matches = append(matches, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch len(matches) {
|
||||||
|
case 0:
|
||||||
|
return models.GetResponse{}, fmt.Errorf("model %q not found", target)
|
||||||
|
case 1:
|
||||||
|
return matches[0], nil
|
||||||
|
default:
|
||||||
|
choices := make([]string, 0, len(matches))
|
||||||
|
for _, item := range matches {
|
||||||
|
choices = append(choices, fmt.Sprintf("%s/%s", h.resolveProviderName(cc, item.ProviderID), item.ModelID))
|
||||||
|
}
|
||||||
|
return models.GetResponse{}, fmt.Errorf("model %q is ambiguous; use a model ID or provider-qualified name (%s)", target, strings.Join(choices, ", "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) filterModelsByProvider(cc CommandContext, items []models.GetResponse, providerName string) []models.GetResponse {
|
||||||
|
providerName = strings.TrimSpace(providerName)
|
||||||
|
if providerName == "" {
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
filtered := make([]models.GetResponse, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.EqualFold(h.resolveProviderName(cc, item.ProviderID), providerName) {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func modelMarkers(modelID string, settingsResp settings.Settings) []string {
|
||||||
|
var markers []string
|
||||||
|
if modelID == settingsResp.ChatModelID {
|
||||||
|
markers = append(markers, "chat")
|
||||||
|
}
|
||||||
|
if modelID == settingsResp.HeartbeatModelID {
|
||||||
|
markers = append(markers, "heartbeat")
|
||||||
|
}
|
||||||
|
return markers
|
||||||
|
}
|
||||||
|
|
||||||
|
func modelSortRank(model models.GetResponse, settingsResp settings.Settings) int {
|
||||||
|
switch len(modelMarkers(model.ID, settingsResp)) {
|
||||||
|
case 2:
|
||||||
|
return 0
|
||||||
|
case 1:
|
||||||
|
return 1
|
||||||
|
default:
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ func TestParse_Basic(t *testing.T) {
|
|||||||
args []string
|
args []string
|
||||||
}{
|
}{
|
||||||
{"/help", "help", "", nil},
|
{"/help", "help", "", nil},
|
||||||
|
{"/help model", "help", "model", nil},
|
||||||
|
{"/help model set", "help", "model", []string{"set"}},
|
||||||
{"/subagent list", "subagent", "list", nil},
|
{"/subagent list", "subagent", "list", nil},
|
||||||
{"/subagent get mybot", "subagent", "get", []string{"mybot"}},
|
{"/subagent get mybot", "subagent", "get", []string{"mybot"}},
|
||||||
{"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}},
|
{"/schedule create daily \"0 9 * * *\" Send report", "schedule", "create", []string{"daily", "0 9 * * *", "Send", "report"}},
|
||||||
|
|||||||
@@ -8,10 +8,18 @@ import (
|
|||||||
|
|
||||||
// CommandContext carries execution context for a sub-command.
|
// CommandContext carries execution context for a sub-command.
|
||||||
type CommandContext struct {
|
type CommandContext struct {
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
BotID string
|
BotID string
|
||||||
Role string // "owner", "admin", "member", or "" (guest)
|
Role string // "owner", "admin", "member", or "" (guest)
|
||||||
Args []string
|
Args []string
|
||||||
|
ChannelIdentityID string
|
||||||
|
UserID string
|
||||||
|
ChannelType string
|
||||||
|
ConversationType string
|
||||||
|
ConversationID string
|
||||||
|
ThreadID string
|
||||||
|
RouteID string
|
||||||
|
SessionID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubCommand describes a single sub-command within a resource group.
|
// SubCommand describes a single sub-command within a resource group.
|
||||||
@@ -50,15 +58,38 @@ func (g *CommandGroup) Usage() string {
|
|||||||
fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description)
|
fmt.Fprintf(&b, "/%s - %s\n", g.Name, g.Description)
|
||||||
for _, name := range g.order {
|
for _, name := range g.order {
|
||||||
sub := g.commands[name]
|
sub := g.commands[name]
|
||||||
perm := ""
|
desc := subSummary(sub)
|
||||||
if sub.IsWrite {
|
if sub.IsWrite {
|
||||||
perm = " [owner]"
|
desc += " [owner]"
|
||||||
}
|
}
|
||||||
fmt.Fprintf(&b, "- %s%s\n", sub.Usage, perm)
|
fmt.Fprintf(&b, "- %s\n", desc)
|
||||||
}
|
}
|
||||||
|
fmt.Fprintf(&b, "\nUse /help %s <action> for details.", g.Name)
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *CommandGroup) ActionHelp(action string) string {
|
||||||
|
sub, ok := g.commands[action]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Sprintf("Unknown action %q for /%s.\n\n%s", action, g.Name, g.Usage())
|
||||||
|
}
|
||||||
|
usage, summary := splitUsage(sub.Usage)
|
||||||
|
var b strings.Builder
|
||||||
|
fmt.Fprintf(&b, "/%s %s\n", g.Name, sub.Name)
|
||||||
|
if summary != "" {
|
||||||
|
fmt.Fprintf(&b, "- Summary: %s\n", summary)
|
||||||
|
}
|
||||||
|
if usage == "" {
|
||||||
|
usage = sub.Name
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, "- Usage: /%s %s\n", g.Name, usage)
|
||||||
|
if sub.IsWrite {
|
||||||
|
b.WriteString("- Access: owner only\n")
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, "- Tip: use /help %s to view sibling actions.", g.Name)
|
||||||
|
return strings.TrimRight(b.String(), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
// Registry holds all registered command groups.
|
// Registry holds all registered command groups.
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
groups map[string]*CommandGroup
|
groups map[string]*CommandGroup
|
||||||
@@ -83,12 +114,47 @@ func (r *Registry) GlobalHelp() string {
|
|||||||
b.WriteString("/help - Show this help message\n")
|
b.WriteString("/help - Show this help message\n")
|
||||||
b.WriteString("/new - Start a new conversation (resets session context)\n")
|
b.WriteString("/new - Start a new conversation (resets session context)\n")
|
||||||
b.WriteString("/stop - Stop the current generation\n\n")
|
b.WriteString("/stop - Stop the current generation\n\n")
|
||||||
for i, name := range r.order {
|
for _, name := range r.order {
|
||||||
if i > 0 {
|
|
||||||
b.WriteByte('\n')
|
|
||||||
}
|
|
||||||
group := r.groups[name]
|
group := r.groups[name]
|
||||||
b.WriteString(group.Usage())
|
fmt.Fprintf(&b, "- /%s - %s\n", group.Name, group.Description)
|
||||||
}
|
}
|
||||||
|
b.WriteString("\nUse /help <group> to view actions, e.g. /help model")
|
||||||
return strings.TrimRight(b.String(), "\n")
|
return strings.TrimRight(b.String(), "\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Registry) GroupHelp(name string) string {
|
||||||
|
group, ok := r.groups[name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Sprintf("Unknown command group: /%s\n\n%s", name, r.GlobalHelp())
|
||||||
|
}
|
||||||
|
return group.Usage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) ActionHelp(groupName, action string) string {
|
||||||
|
group, ok := r.groups[groupName]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Sprintf("Unknown command group: /%s\n\n%s", groupName, r.GlobalHelp())
|
||||||
|
}
|
||||||
|
return group.ActionHelp(action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitUsage(usage string) (commandUsage string, summary string) {
|
||||||
|
usage = strings.TrimSpace(usage)
|
||||||
|
if usage == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(usage, " - ", 2)
|
||||||
|
commandUsage = strings.TrimSpace(parts[0])
|
||||||
|
if len(parts) > 1 {
|
||||||
|
summary = strings.TrimSpace(parts[1])
|
||||||
|
}
|
||||||
|
return commandUsage, summary
|
||||||
|
}
|
||||||
|
|
||||||
|
func subSummary(sub SubCommand) string {
|
||||||
|
usage, summary := splitUsage(sub.Usage)
|
||||||
|
if summary == "" {
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s - %s", sub.Name, summary)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
|||||||
Name: "list",
|
Name: "list",
|
||||||
Usage: "list - List all search providers",
|
Usage: "list - List all search providers",
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.searchProvService == nil {
|
||||||
|
return "Search provider service is not available.", nil
|
||||||
|
}
|
||||||
items, err := h.searchProvService.List(cc.Ctx, "")
|
items, err := h.searchProvService.List(cc.Ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -20,14 +23,40 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
|||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
return "No search providers found.", nil
|
return "No search providers found.", nil
|
||||||
}
|
}
|
||||||
records := make([][]kv, 0, len(items))
|
settingsResp, _ := h.getBotSettings(cc)
|
||||||
|
currentRecords := make([][]kv, 0, 1)
|
||||||
|
otherRecords := make([][]kv, 0, len(items))
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
records = append(records, []kv{
|
label := item.Name
|
||||||
{"Name", item.Name},
|
record := []kv{
|
||||||
|
{"Name", label},
|
||||||
{"Provider", item.Provider},
|
{"Provider", item.Provider},
|
||||||
})
|
}
|
||||||
|
if item.ID == settingsResp.SearchProviderID {
|
||||||
|
label += " [current]"
|
||||||
|
record[0].value = label
|
||||||
|
currentRecords = append(currentRecords, record)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
otherRecords = append(otherRecords, record)
|
||||||
}
|
}
|
||||||
return formatItems(records), nil
|
currentRecords = append(currentRecords, otherRecords...)
|
||||||
|
records := currentRecords
|
||||||
|
return formatLimitedItems(records, defaultListLimit, "Use /search current to inspect the active provider."), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
g.Register(SubCommand{
|
||||||
|
Name: "current",
|
||||||
|
Usage: "current - Show the current search provider",
|
||||||
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
|
settingsResp, err := h.getBotSettings(cc)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return formatKV([]kv{{"Search Provider", h.resolveSearchProviderName(cc, settingsResp.SearchProviderID)}}), nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
g.Register(SubCommand{
|
g.Register(SubCommand{
|
||||||
@@ -38,7 +67,11 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
|||||||
if len(cc.Args) < 1 {
|
if len(cc.Args) < 1 {
|
||||||
return "Usage: /search set <name>", nil
|
return "Usage: /search set <name>", nil
|
||||||
}
|
}
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return "Settings service is not available.", nil
|
||||||
|
}
|
||||||
name := cc.Args[0]
|
name := cc.Args[0]
|
||||||
|
before, _ := h.getBotSettings(cc)
|
||||||
items, err := h.searchProvService.List(cc.Ctx, "")
|
items, err := h.searchProvService.List(cc.Ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -51,7 +84,7 @@ func (h *Handler) buildSearchGroup() *CommandGroup {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Search provider set to %q.", item.Name), nil
|
return formatChangedValue("Search provider", h.resolveSearchProviderName(cc, before.SearchProviderID), item.Name), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("Search provider %q not found.", name), nil
|
return fmt.Sprintf("Search provider %q not found.", name), nil
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package command
|
package command
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -106,6 +107,13 @@ func settingsUpdateUsage() string {
|
|||||||
"- --heartbeat_model_id <id>"
|
"- --heartbeat_model_id <id>"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) getBotSettings(cc CommandContext) (settings.Settings, error) {
|
||||||
|
if h.settingsService == nil {
|
||||||
|
return settings.Settings{}, errors.New("settings service is not available")
|
||||||
|
}
|
||||||
|
return h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||||
|
}
|
||||||
|
|
||||||
// resolveModelName resolves a model UUID to "model_name (provider_name)".
|
// resolveModelName resolves a model UUID to "model_name (provider_name)".
|
||||||
func (h *Handler) resolveModelName(cc CommandContext, modelID string) string {
|
func (h *Handler) resolveModelName(cc CommandContext, modelID string) string {
|
||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
|
|||||||
+94
-57
@@ -6,7 +6,9 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *Handler) buildStatusGroup() *CommandGroup {
|
func (h *Handler) buildStatusGroup() *CommandGroup {
|
||||||
@@ -16,78 +18,113 @@ func (h *Handler) buildStatusGroup() *CommandGroup {
|
|||||||
Name: "show",
|
Name: "show",
|
||||||
Usage: "show - Show current session status",
|
Usage: "show - Show current session status",
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if strings.TrimSpace(cc.SessionID) == "" {
|
||||||
|
return "No active session found for this conversation.", nil
|
||||||
|
}
|
||||||
|
return h.renderSessionStatus(cc, cc.SessionID, "current conversation")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
g.Register(SubCommand{
|
||||||
|
Name: "latest",
|
||||||
|
Usage: "latest - Show the latest session status for this bot",
|
||||||
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.queries == nil {
|
||||||
|
return "Session info is not available.", nil
|
||||||
|
}
|
||||||
botUUID, err := parseBotUUID(cc.BotID)
|
botUUID, err := parseBotUUID(cc.BotID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID)
|
sessionID, err := h.queries.GetLatestSessionIDByBot(cc.Ctx, botUUID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
return "No active session found.", nil
|
return "No session found for this bot.", nil
|
||||||
}
|
}
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
return h.renderSessionStatus(cc, sessionID.String(), "latest bot session")
|
||||||
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("count messages: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var usedTokens int64
|
|
||||||
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, sessionID)
|
|
||||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
|
||||||
return "", fmt.Errorf("get usage: %w", err)
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
usedTokens = latestUsage
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("get cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var contextWindowStr string
|
|
||||||
if h.settingsService != nil {
|
|
||||||
s, sErr := h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
|
||||||
if sErr == nil && s.ChatModelID != "" && h.modelsService != nil {
|
|
||||||
m, mErr := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
|
|
||||||
if mErr == nil && m.Config.ContextWindow != nil {
|
|
||||||
contextWindowStr = formatTokens(int64(*m.Config.ContextWindow))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var cacheHitRate float64
|
|
||||||
if cacheRow.TotalInputTokens > 0 {
|
|
||||||
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
|
|
||||||
}
|
|
||||||
|
|
||||||
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, sessionID)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString("Session Status:\n\n")
|
|
||||||
fmt.Fprintf(&b, "- Messages: %d\n", msgCount)
|
|
||||||
if contextWindowStr != "" {
|
|
||||||
fmt.Fprintf(&b, "- Context: %s / %s\n", formatTokens(usedTokens), contextWindowStr)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(&b, "- Context: %s\n", formatTokens(usedTokens))
|
|
||||||
}
|
|
||||||
fmt.Fprintf(&b, "- Cache Hit Rate: %.1f%%\n", cacheHitRate)
|
|
||||||
fmt.Fprintf(&b, "- Cache Read: %s\n", formatTokens(cacheRow.CacheReadTokens))
|
|
||||||
fmt.Fprintf(&b, "- Cache Write: %s\n", formatTokens(cacheRow.CacheWriteTokens))
|
|
||||||
|
|
||||||
if len(skills) > 0 {
|
|
||||||
fmt.Fprintf(&b, "- Skills: %s\n", strings.Join(skills, ", "))
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.TrimRight(b.String(), "\n"), nil
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) renderSessionStatus(cc CommandContext, sessionID string, scope string) (string, error) {
|
||||||
|
if h.queries == nil {
|
||||||
|
return "Session info is not available.", nil
|
||||||
|
}
|
||||||
|
pgSessionID, err := parseCommandUUID(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
msgCount, err := h.queries.CountMessagesBySession(cc.Ctx, pgSessionID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("count messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var usedTokens int64
|
||||||
|
latestUsage, err := h.queries.GetLatestAssistantUsage(cc.Ctx, pgSessionID)
|
||||||
|
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return "", fmt.Errorf("get usage: %w", err)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
usedTokens = latestUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheRow, err := h.queries.GetSessionCacheStats(cc.Ctx, pgSessionID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get cache: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cacheHitRate float64
|
||||||
|
if cacheRow.TotalInputTokens > 0 {
|
||||||
|
cacheHitRate = float64(cacheRow.CacheReadTokens) / float64(cacheRow.TotalInputTokens) * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
skills, _ := h.queries.GetSessionUsedSkills(cc.Ctx, pgSessionID)
|
||||||
|
|
||||||
|
contextUsage := formatTokens(usedTokens)
|
||||||
|
if contextWindow := h.resolveContextWindow(cc); contextWindow != "" {
|
||||||
|
contextUsage = contextUsage + " / " + contextWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
pairs := []kv{
|
||||||
|
{"Scope", scope},
|
||||||
|
{"Session ID", sessionID},
|
||||||
|
{"Messages", strconv.FormatInt(msgCount, 10)},
|
||||||
|
{"Context", contextUsage},
|
||||||
|
{"Cache Hit Rate", fmt.Sprintf("%.1f%%", cacheHitRate)},
|
||||||
|
{"Cache Read", formatTokens(cacheRow.CacheReadTokens)},
|
||||||
|
{"Cache Write", formatTokens(cacheRow.CacheWriteTokens)},
|
||||||
|
}
|
||||||
|
if len(skills) > 0 {
|
||||||
|
pairs = append(pairs, kv{"Skills", strings.Join(skills, ", ")})
|
||||||
|
}
|
||||||
|
return formatKV(pairs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) resolveContextWindow(cc CommandContext) string {
|
||||||
|
if h.settingsService == nil || h.modelsService == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
s, err := h.settingsService.GetBot(cc.Ctx, cc.BotID)
|
||||||
|
if err != nil || s.ChatModelID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
m, err := h.modelsService.GetByID(cc.Ctx, s.ChatModelID)
|
||||||
|
if err != nil || m.Config.ContextWindow == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return formatTokens(int64(*m.Config.ContextWindow))
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCommandUUID(id string) (pgtype.UUID, error) {
|
||||||
|
parsed, err := uuid.Parse(strings.TrimSpace(id))
|
||||||
|
if err != nil {
|
||||||
|
return pgtype.UUID{}, fmt.Errorf("invalid uuid: %w", err)
|
||||||
|
}
|
||||||
|
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func formatTokens(n int64) string {
|
func formatTokens(n int64) string {
|
||||||
if n >= 1_000_000 {
|
if n >= 1_000_000 {
|
||||||
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
|
return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
|
|||||||
Name: "summary",
|
Name: "summary",
|
||||||
Usage: "summary - Token usage summary (last 7 days)",
|
Usage: "summary - Token usage summary (last 7 days)",
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.queries == nil {
|
||||||
|
return "Usage info is not available.", nil
|
||||||
|
}
|
||||||
botUUID, err := parseBotUUID(cc.BotID)
|
botUUID, err := parseBotUUID(cc.BotID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -89,6 +92,9 @@ func (h *Handler) buildUsageGroup() *CommandGroup {
|
|||||||
Name: "by-model",
|
Name: "by-model",
|
||||||
Usage: "by-model - Token usage grouped by model",
|
Usage: "by-model - Token usage grouped by model",
|
||||||
Handler: func(cc CommandContext) (string, error) {
|
Handler: func(cc CommandContext) (string, error) {
|
||||||
|
if h.queries == nil {
|
||||||
|
return "Usage info is not available.", nil
|
||||||
|
}
|
||||||
botUUID, err := parseBotUUID(cc.BotID)
|
botUUID, err := parseBotUUID(cc.BotID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|||||||
Reference in New Issue
Block a user