mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(channel): redact credentials from IM error messages (#240)
This commit is contained in:
@@ -108,6 +108,7 @@ func (*DiscordAdapter) Descriptor() channel.Descriptor {
|
||||
}
|
||||
|
||||
func (a *DiscordAdapter) getOrCreateSession(token, configID string) (*discordgo.Session, error) {
|
||||
channel.SetIMErrorSecrets("discord:"+configID, token)
|
||||
a.mu.RLock()
|
||||
session, ok := a.sessions[token]
|
||||
a.mu.RUnlock()
|
||||
|
||||
@@ -76,7 +76,7 @@ func (s *discordOutboundStream) Push(ctx context.Context, event channel.StreamEv
|
||||
return nil
|
||||
|
||||
case channel.StreamEventError:
|
||||
errText := strings.TrimSpace(event.Error)
|
||||
errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error))
|
||||
if errText == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) }
|
||||
|
||||
func TestDiscordOutboundStream_PushErrorEventRedactsSecrets(t *testing.T) {
|
||||
channel.ResetIMErrorSecretsForTest()
|
||||
t.Cleanup(channel.ResetIMErrorSecretsForTest)
|
||||
|
||||
const token = "discord-token-ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
channel.SetIMErrorSecrets("test", token)
|
||||
prefixHalf := token[:len(token)/2]
|
||||
|
||||
var sentBody string
|
||||
session, err := discordgo.New("Bot test")
|
||||
if err != nil {
|
||||
t.Fatalf("create session: %v", err)
|
||||
}
|
||||
session.Client = &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
sentBody = string(body)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"msg-1","channel_id":"ch-1"}`)),
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
}
|
||||
return resp, nil
|
||||
}),
|
||||
}
|
||||
|
||||
stream := &discordOutboundStream{
|
||||
adapter: &DiscordAdapter{},
|
||||
target: "ch-1",
|
||||
session: session,
|
||||
}
|
||||
|
||||
err = stream.Push(context.Background(), channel.StreamEvent{
|
||||
Type: channel.StreamEventError,
|
||||
Error: "request failed: " + prefixHalf,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("push error event: %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(sentBody), &payload); err != nil {
|
||||
t.Fatalf("decode sent body: %v (body=%q)", err, sentBody)
|
||||
}
|
||||
content, _ := payload["content"].(string)
|
||||
if strings.Contains(content, prefixHalf) {
|
||||
t.Fatalf("expected prefix half to be redacted, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "Error: ") {
|
||||
t.Fatalf("expected error prefix, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, strings.Repeat("*", len(prefixHalf))) {
|
||||
t.Fatalf("expected redaction mask, got %q", content)
|
||||
}
|
||||
}
|
||||
@@ -192,3 +192,12 @@ func (c Config) openBaseURL() string {
|
||||
}
|
||||
return lark.FeishuBaseUrl
|
||||
}
|
||||
|
||||
func (c Config) registerIMErrorSecrets() {
|
||||
channel.SetIMErrorSecrets("feishu:"+c.AppID, c.AppSecret, c.EncryptKey, c.VerificationToken)
|
||||
}
|
||||
|
||||
func (c Config) newClient() *lark.Client {
|
||||
c.registerIMErrorSecrets()
|
||||
return lark.NewClient(c.AppID, c.AppSecret, lark.WithOpenBaseUrl(c.openBaseURL()))
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ func (*FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
pageSize := directoryLimit(query.Limit)
|
||||
req := larkcontact.NewListUserReqBuilder().
|
||||
UserIdType(larkcontact.UserIdTypeOpenId).
|
||||
@@ -66,7 +66,7 @@ func (*FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
pageSize := directoryLimit(query.Limit)
|
||||
var items []*larkim.ListChat
|
||||
if strings.TrimSpace(query.Query) != "" {
|
||||
@@ -115,7 +115,7 @@ func (*FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelC
|
||||
if chatID == "" {
|
||||
return nil, errors.New("feishu list group members: empty group id")
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
pageSize := directoryLimit(query.Limit)
|
||||
req := larkim.NewGetChatMembersReqBuilder().
|
||||
ChatId(chatID).
|
||||
@@ -146,7 +146,7 @@ func (a *FeishuAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelCon
|
||||
if err != nil {
|
||||
return channel.DirectoryEntry{}, err
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
input = strings.TrimSpace(input)
|
||||
switch kind {
|
||||
case channel.DirectoryEntryUser:
|
||||
|
||||
@@ -227,7 +227,7 @@ func (*FeishuAdapter) processingReactionGateway(cfg channel.ChannelConfig) (proc
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
gateway := &larkProcessingReactionGateway{api: client.Im.MessageReaction}
|
||||
return gateway, nil
|
||||
}
|
||||
@@ -291,7 +291,7 @@ func (*FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string]a
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
client := lark.NewClient(cfg.AppID, cfg.AppSecret, lark.WithOpenBaseUrl(cfg.openBaseURL()))
|
||||
client := cfg.newClient()
|
||||
resp, err := client.Get(ctx, "/open-apis/bot/v3/info", nil, larkcore.AccessTokenTypeTenant)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("feishu discover self: %w", err)
|
||||
@@ -466,6 +466,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig,
|
||||
eventDispatcher.OnP2MessageReactionDeletedV1(func(_ context.Context, _ *larkim.P2MessageReactionDeletedV1) error {
|
||||
return nil
|
||||
})
|
||||
feishuCfg.registerIMErrorSecrets()
|
||||
return larkws.NewClient(
|
||||
feishuCfg.AppID,
|
||||
feishuCfg.AppSecret,
|
||||
@@ -526,7 +527,7 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg
|
||||
return err
|
||||
}
|
||||
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
|
||||
if len(msg.Message.Attachments) > 0 {
|
||||
for _, att := range msg.Message.Attachments {
|
||||
@@ -603,7 +604,7 @@ func (a *FeishuAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -855,7 +856,7 @@ func (*FeishuAdapter) ResolveAttachment(ctx context.Context, cfg channel.Channel
|
||||
if err != nil {
|
||||
return channel.AttachmentPayload{}, err
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
|
||||
resourceType := "file"
|
||||
if isFeishuImageAttachment(attachment) {
|
||||
|
||||
@@ -75,7 +75,7 @@ func (*FeishuAdapter) lookupSenderProfile(ctx context.Context, cfg channel.Chann
|
||||
if err != nil {
|
||||
return feishuSenderProfile{}, err
|
||||
}
|
||||
client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL()))
|
||||
client := feishuCfg.newClient()
|
||||
|
||||
var lastErr error
|
||||
chatID = strings.TrimSpace(chatID)
|
||||
|
||||
@@ -133,7 +133,7 @@ func (s *feishuOutboundStream) Push(ctx context.Context, event channel.StreamEve
|
||||
}
|
||||
return nil
|
||||
case channel.StreamEventError:
|
||||
errText := strings.TrimSpace(event.Error)
|
||||
errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error))
|
||||
if errText == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
type qqClient struct {
|
||||
@@ -105,6 +107,7 @@ func (c *qqClient) accessToken(ctx context.Context) (string, error) {
|
||||
}
|
||||
c.token = token
|
||||
c.expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
channel.SetIMErrorSecrets("qq-token:"+c.appID, c.clientSecret, c.token)
|
||||
return c.token, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ func (*QQAdapter) ProcessingFailed(context.Context, channel.ChannelConfig, chann
|
||||
}
|
||||
|
||||
func (a *QQAdapter) getOrCreateClient(cfg channel.ChannelConfig, parsed Config) *qqClient {
|
||||
channel.SetIMErrorSecrets("qq:"+parsed.AppID, parsed.AppSecret)
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package qq
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -23,6 +24,11 @@ type qqOutboundStream struct {
|
||||
}
|
||||
|
||||
func (a *QQAdapter) OpenStream(_ context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) {
|
||||
parsed, err := parseConfig(cfg.Credentials)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qq open stream: %w", err)
|
||||
}
|
||||
channel.SetIMErrorSecrets("qq:"+parsed.AppID, parsed.AppSecret)
|
||||
return &qqOutboundStream{
|
||||
target: target,
|
||||
reply: opts.Reply,
|
||||
@@ -80,7 +86,7 @@ func (s *qqOutboundStream) Push(ctx context.Context, event channel.StreamEvent)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
case channel.StreamEventError:
|
||||
errText := strings.TrimSpace(event.Error)
|
||||
errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error))
|
||||
if errText == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package qq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
@@ -169,3 +170,32 @@ func TestQQOutboundStreamRejectsAfterClose(t *testing.T) {
|
||||
t.Fatal("expected closed error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQQOutboundStreamErrorRedactsRegisteredTokenFragments(t *testing.T) {
|
||||
channel.ResetIMErrorSecretsForTest()
|
||||
t.Cleanup(channel.ResetIMErrorSecretsForTest)
|
||||
|
||||
const token = "qq-token-ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
channel.SetIMErrorSecrets("test", token)
|
||||
prefixHalf := token[:len(token)/2]
|
||||
|
||||
var sent []channel.OutboundMessage
|
||||
stream := &qqOutboundStream{
|
||||
target: "c2c:user-openid",
|
||||
send: func(_ context.Context, msg channel.OutboundMessage) error {
|
||||
sent = append(sent, msg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
err := stream.Push(context.Background(), channel.StreamEvent{Type: channel.StreamEventError, Error: "failed: " + prefixHalf})
|
||||
if err != nil {
|
||||
t.Fatalf("push error: %v", err)
|
||||
}
|
||||
if len(sent) != 1 {
|
||||
t.Fatalf("expected one outbound message, got %d", len(sent))
|
||||
}
|
||||
if got := sent[0].Message.PlainText(); strings.Contains(got, prefixHalf) {
|
||||
t.Fatalf("expected redacted token fragment, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -439,7 +439,7 @@ func (s *telegramOutboundStream) pushFinal(ctx context.Context, event channel.St
|
||||
}
|
||||
|
||||
func (s *telegramOutboundStream) pushError(ctx context.Context, event channel.StreamEvent) error {
|
||||
errText := strings.TrimSpace(event.Error)
|
||||
errText := channel.RedactIMErrorText(strings.TrimSpace(event.Error))
|
||||
if errText == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -105,6 +105,56 @@ func TestTelegramOutboundStream_PushErrorEventEmptyNoOp(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTelegramOutboundStream_PushErrorEventRedactsRegisteredTokenFragments(t *testing.T) {
|
||||
channel.ResetIMErrorSecretsForTest()
|
||||
t.Cleanup(channel.ResetIMErrorSecretsForTest)
|
||||
|
||||
const botToken = "123456:ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
var sentText string
|
||||
|
||||
adapter := NewTelegramAdapter(nil)
|
||||
stream, err := adapter.OpenStream(context.Background(), channel.ChannelConfig{
|
||||
ID: "cfg-1",
|
||||
Credentials: map[string]any{"botToken": botToken},
|
||||
}, "12345", channel.StreamOptions{Metadata: map[string]any{"conversation_type": "private"}})
|
||||
if err != nil {
|
||||
t.Fatalf("open stream: %v", err)
|
||||
}
|
||||
s, ok := stream.(*telegramOutboundStream)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected stream type %T", stream)
|
||||
}
|
||||
|
||||
origGetBot := getOrCreateBotForTest
|
||||
origSendText := sendTextForTest
|
||||
getOrCreateBotForTest = func(_ *TelegramAdapter, _, _ string) (*tgbotapi.BotAPI, error) {
|
||||
return &tgbotapi.BotAPI{Token: botToken}, nil
|
||||
}
|
||||
sendTextForTest = func(_ *tgbotapi.BotAPI, _ string, text string, _ int, _ string) (int64, int, error) {
|
||||
sentText = text
|
||||
return 1, 1, nil
|
||||
}
|
||||
defer func() {
|
||||
getOrCreateBotForTest = origGetBot
|
||||
sendTextForTest = origSendText
|
||||
}()
|
||||
|
||||
prefixHalf := botToken[:len(botToken)/2]
|
||||
err = s.Push(context.Background(), channel.StreamEvent{Type: channel.StreamEventError, Error: "request failed: " + prefixHalf})
|
||||
if err != nil {
|
||||
t.Fatalf("push error event: %v", err)
|
||||
}
|
||||
if strings.Contains(sentText, prefixHalf) {
|
||||
t.Fatalf("expected prefix half to be redacted, got %q", sentText)
|
||||
}
|
||||
if !strings.Contains(sentText, "Error: ") {
|
||||
t.Fatalf("expected error prefix, got %q", sentText)
|
||||
}
|
||||
if !strings.Contains(sentText, strings.Repeat("*", len(prefixHalf))) {
|
||||
t.Fatalf("expected redaction mask, got %q", sentText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTelegramOutboundStream_CloseContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ func (a *TelegramAdapter) SetAssetOpener(opener assetOpener) {
|
||||
var getOrCreateBotForTest func(a *TelegramAdapter, token, configID string) (*tgbotapi.BotAPI, error)
|
||||
|
||||
func (a *TelegramAdapter) getOrCreateBot(cfg Config, configID string) (*tgbotapi.BotAPI, error) {
|
||||
channel.SetIMErrorSecrets("telegram:"+configID, cfg.BotToken)
|
||||
if getOrCreateBotForTest != nil {
|
||||
return getOrCreateBotForTest(a, cfg.BotToken, configID)
|
||||
}
|
||||
@@ -644,6 +645,11 @@ func (a *TelegramAdapter) OpenStream(ctx context.Context, cfg channel.ChannelCon
|
||||
if target == "" {
|
||||
return nil, errors.New("telegram target is required")
|
||||
}
|
||||
telegramCfg, err := parseConfig(cfg.Credentials)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("telegram open stream: %w", err)
|
||||
}
|
||||
channel.SetIMErrorSecrets("telegram:"+cfg.ID, telegramCfg.BotToken)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var imErrorRedactionRegistry = struct {
|
||||
mu sync.RWMutex
|
||||
groups map[string][]string // key → variants
|
||||
cache []string // deduplicated, sorted longest-first
|
||||
}{
|
||||
groups: map[string][]string{},
|
||||
}
|
||||
|
||||
// SetIMErrorSecrets associates a set of secrets with the given key.
|
||||
// Calling again with the same key replaces the previous secrets,
|
||||
// so rotating credentials (e.g. access tokens) are handled naturally
|
||||
// without explicit unregistration.
|
||||
//
|
||||
// The key should identify the credential scope, e.g. "qq-token:<appID>"
|
||||
// or "telegram:<configID>". For multiple instances of the same adapter,
|
||||
// include a stable instance identifier in the key.
|
||||
//
|
||||
// This is intentionally scoped to IM error rendering only: logs and
|
||||
// normal outbound messages keep their original text so operators can
|
||||
// debug issues and user content is not mutated.
|
||||
func SetIMErrorSecrets(key string, secrets ...string) {
|
||||
var variants []string
|
||||
for _, secret := range secrets {
|
||||
variants = append(variants, imErrorRedactionVariants(secret)...)
|
||||
}
|
||||
|
||||
imErrorRedactionRegistry.mu.Lock()
|
||||
defer imErrorRedactionRegistry.mu.Unlock()
|
||||
|
||||
if len(variants) == 0 {
|
||||
if _, exists := imErrorRedactionRegistry.groups[key]; !exists {
|
||||
return
|
||||
}
|
||||
delete(imErrorRedactionRegistry.groups, key)
|
||||
} else {
|
||||
if slices.Equal(imErrorRedactionRegistry.groups[key], variants) {
|
||||
return
|
||||
}
|
||||
imErrorRedactionRegistry.groups[key] = variants
|
||||
}
|
||||
imErrorRedactionRegistry.cache = rebuildSecretCache(imErrorRedactionRegistry.groups)
|
||||
}
|
||||
|
||||
// RedactIMErrorText masks registered secrets from error text that is about to
|
||||
// be rendered back into an IM conversation.
|
||||
func RedactIMErrorText(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
imErrorRedactionRegistry.mu.RLock()
|
||||
cache := imErrorRedactionRegistry.cache
|
||||
imErrorRedactionRegistry.mu.RUnlock()
|
||||
|
||||
result := text
|
||||
for _, secret := range cache {
|
||||
result = strings.ReplaceAll(result, secret, strings.Repeat("*", utf8.RuneCountInString(secret)))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func imErrorRedactionVariants(secret string) []string {
|
||||
secret = strings.TrimSpace(secret)
|
||||
if secret == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
variants := []string{secret}
|
||||
runes := []rune(secret)
|
||||
half := len(runes) / 2
|
||||
if half > 5 {
|
||||
variants = append(variants, string(runes[:half]), string(runes[len(runes)-half:]))
|
||||
}
|
||||
if encoded := url.QueryEscape(secret); encoded != secret {
|
||||
variants = append(variants, encoded)
|
||||
}
|
||||
return variants
|
||||
}
|
||||
|
||||
func rebuildSecretCache(groups map[string][]string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var all []string
|
||||
for _, variants := range groups {
|
||||
for _, v := range variants {
|
||||
if _, ok := seen[v]; ok {
|
||||
continue
|
||||
}
|
||||
seen[v] = struct{}{}
|
||||
all = append(all, v)
|
||||
}
|
||||
}
|
||||
sort.Slice(all, func(i, j int) bool {
|
||||
li := utf8.RuneCountInString(all[i])
|
||||
lj := utf8.RuneCountInString(all[j])
|
||||
if li == lj {
|
||||
return all[i] < all[j]
|
||||
}
|
||||
return li > lj
|
||||
})
|
||||
return all
|
||||
}
|
||||
|
||||
func resetIMErrorSecretsForTest() {
|
||||
imErrorRedactionRegistry.mu.Lock()
|
||||
defer imErrorRedactionRegistry.mu.Unlock()
|
||||
imErrorRedactionRegistry.groups = map[string][]string{}
|
||||
imErrorRedactionRegistry.cache = nil
|
||||
}
|
||||
|
||||
// ResetIMErrorSecretsForTest clears the IM error redaction registry.
|
||||
// It is intended for tests in other packages that need deterministic state.
|
||||
func ResetIMErrorSecretsForTest() {
|
||||
resetIMErrorSecretsForTest()
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestRedactIMErrorText_RedactsFullSecretAndBothHalves(t *testing.T) {
|
||||
resetIMErrorSecretsForTest()
|
||||
t.Cleanup(resetIMErrorSecretsForTest)
|
||||
|
||||
const secret = "123456:ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
SetIMErrorSecrets("test", secret)
|
||||
|
||||
runes := []rune(secret)
|
||||
half := len(runes) / 2
|
||||
prefixHalf := string(runes[:half])
|
||||
suffixHalf := string(runes[len(runes)-half:])
|
||||
|
||||
input := strings.Join([]string{
|
||||
"full=" + secret,
|
||||
"prefix=" + prefixHalf,
|
||||
"suffix=" + suffixHalf,
|
||||
}, " ")
|
||||
|
||||
got := RedactIMErrorText(input)
|
||||
if strings.Contains(got, secret) {
|
||||
t.Fatalf("full secret should be redacted: %q", got)
|
||||
}
|
||||
if strings.Contains(got, prefixHalf) {
|
||||
t.Fatalf("prefix half should be redacted: %q", got)
|
||||
}
|
||||
if strings.Contains(got, suffixHalf) {
|
||||
t.Fatalf("suffix half should be redacted: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, strings.Repeat("*", utf8.RuneCountInString(secret))) {
|
||||
t.Fatalf("full secret mask missing: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactIMErrorText_DoesNotRegisterShortHalfFragments(t *testing.T) {
|
||||
resetIMErrorSecretsForTest()
|
||||
t.Cleanup(resetIMErrorSecretsForTest)
|
||||
|
||||
const secret = "ABCDEFGHIJ"
|
||||
SetIMErrorSecrets("test", secret)
|
||||
|
||||
runes := []rune(secret)
|
||||
shortHalf := string(runes[:len(runes)/2])
|
||||
|
||||
got := RedactIMErrorText("partial=" + shortHalf)
|
||||
if got != "partial="+shortHalf {
|
||||
t.Fatalf("short half fragment should not be redacted: %q", got)
|
||||
}
|
||||
|
||||
got = RedactIMErrorText("full=" + secret)
|
||||
if strings.Contains(got, secret) {
|
||||
t.Fatalf("full secret should still be redacted: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactIMErrorText_RedactsURLEncodedVariant(t *testing.T) {
|
||||
resetIMErrorSecretsForTest()
|
||||
t.Cleanup(resetIMErrorSecretsForTest)
|
||||
|
||||
const secret = "123456:ABC+DEF/GHI=JKL"
|
||||
SetIMErrorSecrets("test", secret)
|
||||
|
||||
encoded := url.QueryEscape(secret)
|
||||
if encoded == secret {
|
||||
t.Fatal("test secret must differ when URL-encoded")
|
||||
}
|
||||
|
||||
got := RedactIMErrorText("url=" + encoded)
|
||||
if strings.Contains(got, encoded) {
|
||||
t.Fatalf("URL-encoded secret should be redacted: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIMErrorSecrets_ReplacesOnSameKey(t *testing.T) {
|
||||
resetIMErrorSecretsForTest()
|
||||
t.Cleanup(resetIMErrorSecretsForTest)
|
||||
|
||||
const oldToken = "old-rotating-token-ABCDEFGHIJKLMNO"
|
||||
const newToken = "new-rotating-token-XYZXYZXYZXYZXYZ"
|
||||
|
||||
SetIMErrorSecrets("qq-token:app1", oldToken)
|
||||
|
||||
got := RedactIMErrorText("err: " + oldToken)
|
||||
if strings.Contains(got, oldToken) {
|
||||
t.Fatalf("old token should be redacted: %q", got)
|
||||
}
|
||||
|
||||
// Simulate token rotation: same key, new value
|
||||
SetIMErrorSecrets("qq-token:app1", newToken)
|
||||
|
||||
got = RedactIMErrorText("err: " + oldToken)
|
||||
if !strings.Contains(got, oldToken) {
|
||||
t.Fatalf("old token should no longer be redacted after replacement: %q", got)
|
||||
}
|
||||
got = RedactIMErrorText("err: " + newToken)
|
||||
if strings.Contains(got, newToken) {
|
||||
t.Fatalf("new token should be redacted: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIMErrorSecrets_IndependentKeys(t *testing.T) {
|
||||
resetIMErrorSecretsForTest()
|
||||
t.Cleanup(resetIMErrorSecretsForTest)
|
||||
|
||||
const secretA = "secret-AAAAAAAAAAAAAAAA"
|
||||
const secretB = "secret-BBBBBBBBBBBBBBBB"
|
||||
|
||||
SetIMErrorSecrets("key-a", secretA)
|
||||
SetIMErrorSecrets("key-b", secretB)
|
||||
|
||||
// Replacing key-a should not affect key-b
|
||||
SetIMErrorSecrets("key-a", "secret-CCCCCCCCCCCCCCCC")
|
||||
|
||||
got := RedactIMErrorText("err: " + secretB)
|
||||
if strings.Contains(got, secretB) {
|
||||
t.Fatalf("secretB should still be redacted: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
mcpgw "github.com/memohai/memoh/internal/mcp"
|
||||
"github.com/memohai/memoh/internal/searchproviders"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
@@ -87,6 +89,7 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
|
||||
if err != nil {
|
||||
return mcpgw.BuildToolErrorResult(err.Error()), nil
|
||||
}
|
||||
registerSearchProviderSecrets(provider)
|
||||
|
||||
switch toolName {
|
||||
case toolWebSearch:
|
||||
@@ -1145,3 +1148,19 @@ func firstNonEmpty(values ...string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// searchProviderSecretFields are config keys known to hold credentials.
|
||||
var searchProviderSecretFields = []string{"api_key", "secret_id", "secret_key"}
|
||||
|
||||
func registerSearchProviderSecrets(provider sqlc.SearchProvider) {
|
||||
cfg := parseConfig(provider.Config)
|
||||
var secrets []string
|
||||
for _, key := range searchProviderSecretFields {
|
||||
if v := stringValue(cfg[key]); v != "" {
|
||||
secrets = append(secrets, v)
|
||||
}
|
||||
}
|
||||
if len(secrets) > 0 {
|
||||
channel.SetIMErrorSecrets("search:"+provider.ID.String(), secrets...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
@@ -487,7 +488,14 @@ func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID st
|
||||
if err != nil {
|
||||
return sqlc.LlmProvider{}, err
|
||||
}
|
||||
return queries.GetLlmProviderByID(ctx, parsed)
|
||||
provider, err := queries.GetLlmProviderByID(ctx, parsed)
|
||||
if err != nil {
|
||||
return sqlc.LlmProvider{}, err
|
||||
}
|
||||
if strings.TrimSpace(provider.ApiKey) != "" {
|
||||
channel.SetIMErrorSecrets("llm-provider:"+providerID, provider.ApiKey)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func intToInt4(value int, name string) (pgtype.Int4, error) {
|
||||
|
||||
Reference in New Issue
Block a user