diff --git a/internal/channel/adapters/discord/discord.go b/internal/channel/adapters/discord/discord.go index c1a3044d..80e00473 100644 --- a/internal/channel/adapters/discord/discord.go +++ b/internal/channel/adapters/discord/discord.go @@ -108,6 +108,7 @@ func (*DiscordAdapter) Descriptor() channel.Descriptor { } func (a *DiscordAdapter) getOrCreateSession(token, configID string) (*discordgo.Session, error) { + channel.SetIMErrorSecrets("discord:"+configID, token) a.mu.RLock() session, ok := a.sessions[token] a.mu.RUnlock() diff --git a/internal/channel/adapters/discord/stream.go b/internal/channel/adapters/discord/stream.go index 9e56e521..5193a9dc 100644 --- a/internal/channel/adapters/discord/stream.go +++ b/internal/channel/adapters/discord/stream.go @@ -76,7 +76,7 @@ func (s *discordOutboundStream) Push(ctx context.Context, event channel.StreamEv return nil case channel.StreamEventError: - errText := strings.TrimSpace(event.Error) + errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error)) if errText == "" { return nil } diff --git a/internal/channel/adapters/discord/stream_test.go b/internal/channel/adapters/discord/stream_test.go new file mode 100644 index 00000000..dae64546 --- /dev/null +++ b/internal/channel/adapters/discord/stream_test.go @@ -0,0 +1,74 @@ +package discord + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "github.com/bwmarrin/discordgo" + + "github.com/memohai/memoh/internal/channel" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } + +func TestDiscordOutboundStream_PushErrorEventRedactsSecrets(t *testing.T) { + channel.ResetIMErrorSecretsForTest() + t.Cleanup(channel.ResetIMErrorSecretsForTest) + + const token = "discord-token-ABCDEFGHIJKLMNOPQRSTUVWXYZ" + channel.SetIMErrorSecrets("test", token) + prefixHalf := token[:len(token)/2] + + var sentBody string + session, err := discordgo.New("Bot test") + if err != nil { + t.Fatalf("create session: %v", err) + } + session.Client = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + body, _ := io.ReadAll(req.Body) + sentBody = string(body) + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"msg-1","channel_id":"ch-1"}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + } + return resp, nil + }), + } + + stream := &discordOutboundStream{ + adapter: &DiscordAdapter{}, + target: "ch-1", + session: session, + } + + err = stream.Push(context.Background(), channel.StreamEvent{ + Type: channel.StreamEventError, + Error: "request failed: " + prefixHalf, + }) + if err != nil { + t.Fatalf("push error event: %v", err) + } + + var payload map[string]any + if err := json.Unmarshal([]byte(sentBody), &payload); err != nil { + t.Fatalf("decode sent body: %v (body=%q)", err, sentBody) + } + content, _ := payload["content"].(string) + if strings.Contains(content, prefixHalf) { + t.Fatalf("expected prefix half to be redacted, got %q", content) + } + if !strings.Contains(content, "Error: ") { + t.Fatalf("expected error prefix, got %q", content) + } + if !strings.Contains(content, strings.Repeat("*", len(prefixHalf))) { + t.Fatalf("expected redaction mask, got %q", content) + } +} diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go index 7a0e87df..d96bec75 100644 --- a/internal/channel/adapters/feishu/config.go +++ b/internal/channel/adapters/feishu/config.go @@ -192,3 +192,12 @@ func (c Config) openBaseURL() string { } return lark.FeishuBaseUrl } + +func (c Config) registerIMErrorSecrets() { + channel.SetIMErrorSecrets("feishu:"+c.AppID, c.AppSecret, c.EncryptKey, c.VerificationToken) +} + +func (c Config) newClient() *lark.Client { + c.registerIMErrorSecrets() + return lark.NewClient(c.AppID, c.AppSecret, lark.WithOpenBaseUrl(c.openBaseURL())) +} diff --git a/internal/channel/adapters/feishu/directory.go b/internal/channel/adapters/feishu/directory.go index bfcbe084..b4023091 100644 --- a/internal/channel/adapters/feishu/directory.go +++ b/internal/channel/adapters/feishu/directory.go @@ -34,7 +34,7 @@ func (*FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, if err != nil { return nil, err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() pageSize := directoryLimit(query.Limit) req := larkcontact.NewListUserReqBuilder(). UserIdType(larkcontact.UserIdTypeOpenId). @@ -66,7 +66,7 @@ func (*FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, if err != nil { return nil, err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() pageSize := directoryLimit(query.Limit) var items []*larkim.ListChat if strings.TrimSpace(query.Query) != "" { @@ -115,7 +115,7 @@ func (*FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelC if chatID == "" { return nil, errors.New("feishu list group members: empty group id") } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() pageSize := directoryLimit(query.Limit) req := larkim.NewGetChatMembersReqBuilder(). ChatId(chatID). @@ -146,7 +146,7 @@ func (a *FeishuAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelCon if err != nil { return channel.DirectoryEntry{}, err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() input = strings.TrimSpace(input) switch kind { case channel.DirectoryEntryUser: diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index 833df38a..035e1b97 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -227,7 +227,7 @@ func (*FeishuAdapter) processingReactionGateway(cfg channel.ChannelConfig) (proc if err != nil { return nil, err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() gateway := &larkProcessingReactionGateway{api: client.Im.MessageReaction} return gateway, nil } @@ -291,7 +291,7 @@ func (*FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string]a if err != nil { return nil, "", err } - client := lark.NewClient(cfg.AppID, cfg.AppSecret, lark.WithOpenBaseUrl(cfg.openBaseURL())) + client := cfg.newClient() resp, err := client.Get(ctx, "/open-apis/bot/v3/info", nil, larkcore.AccessTokenTypeTenant) if err != nil { return nil, "", fmt.Errorf("feishu discover self: %w", err) @@ -466,6 +466,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, eventDispatcher.OnP2MessageReactionDeletedV1(func(_ context.Context, _ *larkim.P2MessageReactionDeletedV1) error { return nil }) + feishuCfg.registerIMErrorSecrets() return larkws.NewClient( feishuCfg.AppID, feishuCfg.AppSecret, @@ -526,7 +527,7 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg return err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() if len(msg.Message.Attachments) > 0 { for _, att := range msg.Message.Attachments { @@ -603,7 +604,7 @@ func (a *FeishuAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfi if err != nil { return nil, err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() select { case <-ctx.Done(): return nil, ctx.Err() @@ -855,7 +856,7 @@ func (*FeishuAdapter) ResolveAttachment(ctx context.Context, cfg channel.Channel if err != nil { return channel.AttachmentPayload{}, err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() resourceType := "file" if isFeishuImageAttachment(attachment) { diff --git a/internal/channel/adapters/feishu/sender_profile.go b/internal/channel/adapters/feishu/sender_profile.go index 057a9eb4..c76133d7 100644 --- a/internal/channel/adapters/feishu/sender_profile.go +++ b/internal/channel/adapters/feishu/sender_profile.go @@ -75,7 +75,7 @@ func (*FeishuAdapter) lookupSenderProfile(ctx context.Context, cfg channel.Chann if err != nil { return feishuSenderProfile{}, err } - client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) + client := feishuCfg.newClient() var lastErr error chatID = strings.TrimSpace(chatID) diff --git a/internal/channel/adapters/feishu/stream.go b/internal/channel/adapters/feishu/stream.go index 3e93e6df..28bceeb4 100644 --- a/internal/channel/adapters/feishu/stream.go +++ b/internal/channel/adapters/feishu/stream.go @@ -133,7 +133,7 @@ func (s *feishuOutboundStream) Push(ctx context.Context, event channel.StreamEve } return nil case channel.StreamEventError: - errText := strings.TrimSpace(event.Error) + errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error)) if errText == "" { return nil } diff --git a/internal/channel/adapters/qq/client.go b/internal/channel/adapters/qq/client.go index d61d358d..86ae92c3 100644 --- a/internal/channel/adapters/qq/client.go +++ b/internal/channel/adapters/qq/client.go @@ -13,6 +13,8 @@ import ( "strings" "sync" "time" + + "github.com/memohai/memoh/internal/channel" ) type qqClient struct { @@ -105,6 +107,7 @@ func (c *qqClient) accessToken(ctx context.Context) (string, error) { } c.token = token c.expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second) + channel.SetIMErrorSecrets("qq-token:"+c.appID, c.clientSecret, c.token) return c.token, nil } diff --git a/internal/channel/adapters/qq/qq.go b/internal/channel/adapters/qq/qq.go index 07787df9..cd99a4e2 100644 --- a/internal/channel/adapters/qq/qq.go +++ b/internal/channel/adapters/qq/qq.go @@ -224,6 +224,7 @@ func (*QQAdapter) ProcessingFailed(context.Context, channel.ChannelConfig, chann } func (a *QQAdapter) getOrCreateClient(cfg channel.ChannelConfig, parsed Config) *qqClient { + channel.SetIMErrorSecrets("qq:"+parsed.AppID, parsed.AppSecret) a.mu.Lock() defer a.mu.Unlock() diff --git a/internal/channel/adapters/qq/stream.go b/internal/channel/adapters/qq/stream.go index 9677f55f..24bfb9e6 100644 --- a/internal/channel/adapters/qq/stream.go +++ b/internal/channel/adapters/qq/stream.go @@ -3,6 +3,7 @@ package qq import ( "context" "errors" + "fmt" "strings" "sync" "sync/atomic" @@ -23,6 +24,11 @@ type qqOutboundStream struct { } func (a *QQAdapter) OpenStream(_ context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + parsed, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, fmt.Errorf("qq open stream: %w", err) + } + channel.SetIMErrorSecrets("qq:"+parsed.AppID, parsed.AppSecret) return &qqOutboundStream{ target: target, reply: opts.Reply, @@ -80,7 +86,7 @@ func (s *qqOutboundStream) Push(ctx context.Context, event channel.StreamEvent) s.mu.Unlock() return nil case channel.StreamEventError: - errText := strings.TrimSpace(event.Error) + errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error)) if errText == "" { return nil } diff --git a/internal/channel/adapters/qq/stream_test.go b/internal/channel/adapters/qq/stream_test.go index 57d0ad97..487964d0 100644 --- a/internal/channel/adapters/qq/stream_test.go +++ b/internal/channel/adapters/qq/stream_test.go @@ -2,6 +2,7 @@ package qq import ( "context" + "strings" "testing" "github.com/memohai/memoh/internal/channel" @@ -169,3 +170,32 @@ func TestQQOutboundStreamRejectsAfterClose(t *testing.T) { t.Fatal("expected closed error") } } + +func TestQQOutboundStreamErrorRedactsRegisteredTokenFragments(t *testing.T) { + channel.ResetIMErrorSecretsForTest() + t.Cleanup(channel.ResetIMErrorSecretsForTest) + + const token = "qq-token-ABCDEFGHIJKLMNOPQRSTUVWXYZ" + channel.SetIMErrorSecrets("test", token) + prefixHalf := token[:len(token)/2] + + var sent []channel.OutboundMessage + stream := &qqOutboundStream{ + target: "c2c:user-openid", + send: func(_ context.Context, msg channel.OutboundMessage) error { + sent = append(sent, msg) + return nil + }, + } + + err := stream.Push(context.Background(), channel.StreamEvent{Type: channel.StreamEventError, Error: "failed: " + prefixHalf}) + if err != nil { + t.Fatalf("push error: %v", err) + } + if len(sent) != 1 { + t.Fatalf("expected one outbound message, got %d", len(sent)) + } + if got := sent[0].Message.PlainText(); strings.Contains(got, prefixHalf) { + t.Fatalf("expected redacted token fragment, got %q", got) + } +} diff --git a/internal/channel/adapters/telegram/stream.go b/internal/channel/adapters/telegram/stream.go index e97d911e..d278fdff 100644 --- a/internal/channel/adapters/telegram/stream.go +++ b/internal/channel/adapters/telegram/stream.go @@ -439,7 +439,7 @@ func (s *telegramOutboundStream) pushFinal(ctx context.Context, event channel.St } func (s *telegramOutboundStream) pushError(ctx context.Context, event channel.StreamEvent) error { - errText := strings.TrimSpace(event.Error) + errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error)) if errText == "" { return nil } diff --git a/internal/channel/adapters/telegram/stream_test.go b/internal/channel/adapters/telegram/stream_test.go index 5121ec01..f2c49659 100644 --- a/internal/channel/adapters/telegram/stream_test.go +++ b/internal/channel/adapters/telegram/stream_test.go @@ -105,6 +105,56 @@ func TestTelegramOutboundStream_PushErrorEventEmptyNoOp(t *testing.T) { } } +func TestTelegramOutboundStream_PushErrorEventRedactsRegisteredTokenFragments(t *testing.T) { + channel.ResetIMErrorSecretsForTest() + t.Cleanup(channel.ResetIMErrorSecretsForTest) + + const botToken = "123456:ABCDEFGHIJKLMNOPQRSTUVWXYZ" + var sentText string + + adapter := NewTelegramAdapter(nil) + stream, err := adapter.OpenStream(context.Background(), channel.ChannelConfig{ + ID: "cfg-1", + Credentials: map[string]any{"botToken": botToken}, + }, "12345", channel.StreamOptions{Metadata: map[string]any{"conversation_type": "private"}}) + if err != nil { + t.Fatalf("open stream: %v", err) + } + s, ok := stream.(*telegramOutboundStream) + if !ok { + t.Fatalf("unexpected stream type %T", stream) + } + + origGetBot := getOrCreateBotForTest + origSendText := sendTextForTest + getOrCreateBotForTest = func(_ *TelegramAdapter, _, _ string) (*tgbotapi.BotAPI, error) { + return &tgbotapi.BotAPI{Token: botToken}, nil + } + sendTextForTest = func(_ *tgbotapi.BotAPI, _ string, text string, _ int, _ string) (int64, int, error) { + sentText = text + return 1, 1, nil + } + defer func() { + getOrCreateBotForTest = origGetBot + sendTextForTest = origSendText + }() + + prefixHalf := botToken[:len(botToken)/2] + err = s.Push(context.Background(), channel.StreamEvent{Type: channel.StreamEventError, Error: "request failed: " + prefixHalf}) + if err != nil { + t.Fatalf("push error event: %v", err) + } + if strings.Contains(sentText, prefixHalf) { + t.Fatalf("expected prefix half to be redacted, got %q", sentText) + } + if !strings.Contains(sentText, "Error: ") { + t.Fatalf("expected error prefix, got %q", sentText) + } + if !strings.Contains(sentText, strings.Repeat("*", len(prefixHalf))) { + t.Fatalf("expected redaction mask, got %q", sentText) + } +} + func TestTelegramOutboundStream_CloseContextCanceled(t *testing.T) { t.Parallel() diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index dc1165ba..ea857d09 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -82,6 +82,7 @@ func (a *TelegramAdapter) SetAssetOpener(opener assetOpener) { var getOrCreateBotForTest func(a *TelegramAdapter, token, configID string) (*tgbotapi.BotAPI, error) func (a *TelegramAdapter) getOrCreateBot(cfg Config, configID string) (*tgbotapi.BotAPI, error) { + channel.SetIMErrorSecrets("telegram:"+configID, cfg.BotToken) if getOrCreateBotForTest != nil { return getOrCreateBotForTest(a, cfg.BotToken, configID) } @@ -644,6 +645,11 @@ func (a *TelegramAdapter) OpenStream(ctx context.Context, cfg channel.ChannelCon if target == "" { return nil, errors.New("telegram target is required") } + telegramCfg, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, fmt.Errorf("telegram open stream: %w", err) + } + channel.SetIMErrorSecrets("telegram:"+cfg.ID, telegramCfg.BotToken) select { case <-ctx.Done(): return nil, ctx.Err() diff --git a/internal/channel/error_redaction.go b/internal/channel/error_redaction.go new file mode 100644 index 00000000..420ebd64 --- /dev/null +++ b/internal/channel/error_redaction.go @@ -0,0 +1,125 @@ +package channel + +import ( + "net/url" + "slices" + "sort" + "strings" + "sync" + "unicode/utf8" +) + +var imErrorRedactionRegistry = struct { + mu sync.RWMutex + groups map[string][]string // key → variants + cache []string // deduplicated, sorted longest-first +}{ + groups: map[string][]string{}, +} + +// SetIMErrorSecrets associates a set of secrets with the given key. +// Calling again with the same key replaces the previous secrets, +// so rotating credentials (e.g. access tokens) are handled naturally +// without explicit unregistration. +// +// The key should identify the credential scope, e.g. "qq-token:" +// or "telegram:". For multiple instances of the same adapter, +// include a stable instance identifier in the key. +// +// This is intentionally scoped to IM error rendering only: logs and +// normal outbound messages keep their original text so operators can +// debug issues and user content is not mutated. +func SetIMErrorSecrets(key string, secrets ...string) { + var variants []string + for _, secret := range secrets { + variants = append(variants, imErrorRedactionVariants(secret)...) + } + + imErrorRedactionRegistry.mu.Lock() + defer imErrorRedactionRegistry.mu.Unlock() + + if len(variants) == 0 { + if _, exists := imErrorRedactionRegistry.groups[key]; !exists { + return + } + delete(imErrorRedactionRegistry.groups, key) + } else { + if slices.Equal(imErrorRedactionRegistry.groups[key], variants) { + return + } + imErrorRedactionRegistry.groups[key] = variants + } + imErrorRedactionRegistry.cache = rebuildSecretCache(imErrorRedactionRegistry.groups) +} + +// RedactIMErrorText masks registered secrets from error text that is about to +// be rendered back into an IM conversation. +func RedactIMErrorText(text string) string { + if strings.TrimSpace(text) == "" { + return text + } + + imErrorRedactionRegistry.mu.RLock() + cache := imErrorRedactionRegistry.cache + imErrorRedactionRegistry.mu.RUnlock() + + result := text + for _, secret := range cache { + result = strings.ReplaceAll(result, secret, strings.Repeat("*", utf8.RuneCountInString(secret))) + } + return result +} + +func imErrorRedactionVariants(secret string) []string { + secret = strings.TrimSpace(secret) + if secret == "" { + return nil + } + + variants := []string{secret} + runes := []rune(secret) + half := len(runes) / 2 + if half > 5 { + variants = append(variants, string(runes[:half]), string(runes[len(runes)-half:])) + } + if encoded := url.QueryEscape(secret); encoded != secret { + variants = append(variants, encoded) + } + return variants +} + +func rebuildSecretCache(groups map[string][]string) []string { + seen := make(map[string]struct{}) + var all []string + for _, variants := range groups { + for _, v := range variants { + if _, ok := seen[v]; ok { + continue + } + seen[v] = struct{}{} + all = append(all, v) + } + } + sort.Slice(all, func(i, j int) bool { + li := utf8.RuneCountInString(all[i]) + lj := utf8.RuneCountInString(all[j]) + if li == lj { + return all[i] < all[j] + } + return li > lj + }) + return all +} + +func resetIMErrorSecretsForTest() { + imErrorRedactionRegistry.mu.Lock() + defer imErrorRedactionRegistry.mu.Unlock() + imErrorRedactionRegistry.groups = map[string][]string{} + imErrorRedactionRegistry.cache = nil +} + +// ResetIMErrorSecretsForTest clears the IM error redaction registry. +// It is intended for tests in other packages that need deterministic state. +func ResetIMErrorSecretsForTest() { + resetIMErrorSecretsForTest() +} diff --git a/internal/channel/error_redaction_test.go b/internal/channel/error_redaction_test.go new file mode 100644 index 00000000..cd95d7f1 --- /dev/null +++ b/internal/channel/error_redaction_test.go @@ -0,0 +1,126 @@ +package channel + +import ( + "net/url" + "strings" + "testing" + "unicode/utf8" +) + +func TestRedactIMErrorText_RedactsFullSecretAndBothHalves(t *testing.T) { + resetIMErrorSecretsForTest() + t.Cleanup(resetIMErrorSecretsForTest) + + const secret = "123456:ABCDEFGHIJKLMNOPQRSTUVWXYZ" + SetIMErrorSecrets("test", secret) + + runes := []rune(secret) + half := len(runes) / 2 + prefixHalf := string(runes[:half]) + suffixHalf := string(runes[len(runes)-half:]) + + input := strings.Join([]string{ + "full=" + secret, + "prefix=" + prefixHalf, + "suffix=" + suffixHalf, + }, " ") + + got := RedactIMErrorText(input) + if strings.Contains(got, secret) { + t.Fatalf("full secret should be redacted: %q", got) + } + if strings.Contains(got, prefixHalf) { + t.Fatalf("prefix half should be redacted: %q", got) + } + if strings.Contains(got, suffixHalf) { + t.Fatalf("suffix half should be redacted: %q", got) + } + if !strings.Contains(got, strings.Repeat("*", utf8.RuneCountInString(secret))) { + t.Fatalf("full secret mask missing: %q", got) + } +} + +func TestRedactIMErrorText_DoesNotRegisterShortHalfFragments(t *testing.T) { + resetIMErrorSecretsForTest() + t.Cleanup(resetIMErrorSecretsForTest) + + const secret = "ABCDEFGHIJ" + SetIMErrorSecrets("test", secret) + + runes := []rune(secret) + shortHalf := string(runes[:len(runes)/2]) + + got := RedactIMErrorText("partial=" + shortHalf) + if got != "partial="+shortHalf { + t.Fatalf("short half fragment should not be redacted: %q", got) + } + + got = RedactIMErrorText("full=" + secret) + if strings.Contains(got, secret) { + t.Fatalf("full secret should still be redacted: %q", got) + } +} + +func TestRedactIMErrorText_RedactsURLEncodedVariant(t *testing.T) { + resetIMErrorSecretsForTest() + t.Cleanup(resetIMErrorSecretsForTest) + + const secret = "123456:ABC+DEF/GHI=JKL" + SetIMErrorSecrets("test", secret) + + encoded := url.QueryEscape(secret) + if encoded == secret { + t.Fatal("test secret must differ when URL-encoded") + } + + got := RedactIMErrorText("url=" + encoded) + if strings.Contains(got, encoded) { + t.Fatalf("URL-encoded secret should be redacted: %q", got) + } +} + +func TestSetIMErrorSecrets_ReplacesOnSameKey(t *testing.T) { + resetIMErrorSecretsForTest() + t.Cleanup(resetIMErrorSecretsForTest) + + const oldToken = "old-rotating-token-ABCDEFGHIJKLMNO" + const newToken = "new-rotating-token-XYZXYZXYZXYZXYZ" + + SetIMErrorSecrets("qq-token:app1", oldToken) + + got := RedactIMErrorText("err: " + oldToken) + if strings.Contains(got, oldToken) { + t.Fatalf("old token should be redacted: %q", got) + } + + // Simulate token rotation: same key, new value + SetIMErrorSecrets("qq-token:app1", newToken) + + got = RedactIMErrorText("err: " + oldToken) + if !strings.Contains(got, oldToken) { + t.Fatalf("old token should no longer be redacted after replacement: %q", got) + } + got = RedactIMErrorText("err: " + newToken) + if strings.Contains(got, newToken) { + t.Fatalf("new token should be redacted: %q", got) + } +} + +func TestSetIMErrorSecrets_IndependentKeys(t *testing.T) { + resetIMErrorSecretsForTest() + t.Cleanup(resetIMErrorSecretsForTest) + + const secretA = "secret-AAAAAAAAAAAAAAAA" + const secretB = "secret-BBBBBBBBBBBBBBBB" + + SetIMErrorSecrets("key-a", secretA) + SetIMErrorSecrets("key-b", secretB) + + // Replacing key-a should not affect key-b + SetIMErrorSecrets("key-a", "secret-CCCCCCCCCCCCCCCC") + + got := RedactIMErrorText("err: " + secretB) + if strings.Contains(got, secretB) { + t.Fatalf("secretB should still be redacted: %q", got) + } +} diff --git a/internal/mcp/providers/web/provider.go b/internal/mcp/providers/web/provider.go index 8bf400e9..ea4ce4aa 100644 --- a/internal/mcp/providers/web/provider.go +++ b/internal/mcp/providers/web/provider.go @@ -21,6 +21,8 @@ import ( "strings" "time" + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/db/sqlc" mcpgw "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/searchproviders" "github.com/memohai/memoh/internal/settings" @@ -87,6 +89,7 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } + registerSearchProviderSecrets(provider) switch toolName { case toolWebSearch: @@ -1145,3 +1148,19 @@ func firstNonEmpty(values ...string) string { } return "" } + +// searchProviderSecretFields are config keys known to hold credentials. +var searchProviderSecretFields = []string{"api_key", "secret_id", "secret_key"} + +func registerSearchProviderSecrets(provider sqlc.SearchProvider) { + cfg := parseConfig(provider.Config) + var secrets []string + for _, key := range searchProviderSecretFields { + if v := stringValue(cfg[key]); v != "" { + secrets = append(secrets, v) + } + } + if len(secrets) > 0 { + channel.SetIMErrorSecrets("search:"+provider.ID.String(), secrets...) + } +} diff --git a/internal/models/models.go b/internal/models/models.go index 59152398..388cbbda 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -487,7 +488,14 @@ func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID st if err != nil { return sqlc.LlmProvider{}, err } - return queries.GetLlmProviderByID(ctx, parsed) + provider, err := queries.GetLlmProviderByID(ctx, parsed) + if err != nil { + return sqlc.LlmProvider{}, err + } + if strings.TrimSpace(provider.ApiKey) != "" { + channel.SetIMErrorSecrets("llm-provider:"+providerID, provider.ApiKey) + } + return provider, nil } func intToInt4(value int, name string) (pgtype.Int4, error) {