From e6a6dbe3f6f35978e7e7c378f1b3259a2046bfa6 Mon Sep 17 00:00:00 2001 From: "Ringo.Typowriter" Date: Sat, 7 Mar 2026 17:12:06 +0800 Subject: [PATCH] feat(channel): add QQ channel support and image message pipeline (#199) * feat(channel): add qq adapter and outbound delivery * feat(channel): ingest inbound qq messages * feat(web): expose qq channel in management ui * feat(channel): support qq attachment ingestion * fix(mcp): fail read raw immediately for missing files * fix(agent): parse inline image data into native image parts * test(agent): align read_media tool tests with SDK options * fix(channel): harden qq image delivery and reconnect loop Avoid data URLs for qq channel images, reset reconnect backoff after healthy sessions, and fall back gracefully for malformed public image URLs. * fix(channel): restore qq media delivery and target resolution * fix(qq,mcp,agent): fix message/qq regressions and pass go lint * fix(qq,agent): validate inline base64 and sync heartbeat seq * fix(qq): validate remote voice mime for upload checks * fix(qq): fall back intents and restore adapter wiring * fix(qq): prevent final text leakage and dedupe persisted inbound query --- apps/web/src/i18n/locales/en.json | 2 + apps/web/src/i18n/locales/zh.json | 2 + .../pages/bots/components/bot-channels.vue | 2 + apps/web/src/pages/settings/index.vue | 2 +- apps/web/src/utils/channel-icons.ts | 1 + cmd/agent/main.go | 16 + internal/attachment/normalize.go | 3 + internal/attachment/normalize_test.go | 6 + internal/channel/adapters/qq/client.go | 397 +++++++++ internal/channel/adapters/qq/client_test.go | 45 + internal/channel/adapters/qq/config.go | 171 ++++ internal/channel/adapters/qq/config_test.go | 192 +++++ internal/channel/adapters/qq/descriptor.go | 5 + internal/channel/adapters/qq/face_tags.go | 31 + internal/channel/adapters/qq/factory.go | 18 + internal/channel/adapters/qq/qq.go | 264 ++++++ internal/channel/adapters/qq/receive.go | 628 ++++++++++++++ internal/channel/adapters/qq/receive_test.go | 330 ++++++++ internal/channel/adapters/qq/send.go | 358 ++++++++ internal/channel/adapters/qq/send_test.go | 770 ++++++++++++++++++ internal/channel/adapters/qq/stream.go | 152 ++++ internal/channel/adapters/qq/stream_test.go | 171 ++++ .../channel/adapters/qq/target_resolver.go | 136 ++++ .../adapters/qq/target_resolver_test.go | 201 +++++ internal/channel/inbound/channel.go | 9 +- internal/channel/inbound/channel_test.go | 169 +++- internal/channel/outbound.go | 37 +- internal/channel/outbound_test.go | 34 + internal/conversation/flow/resolver.go | 58 +- .../conversation/flow/resolver_dedupe_test.go | 45 + internal/mcp/mcpclient/client.go | 31 +- internal/mcp/mcpclient/client_test.go | 130 +++ internal/mcp/providers/message/provider.go | 33 +- .../mcp/providers/message/provider_test.go | 124 +++ internal/mcp/providers/skill/provider.go | 4 +- internal/mcp/providers/subagent/provider.go | 4 +- internal/mcp/providers/webfetch/provider.go | 8 +- packages/agent/src/agent.test.ts | 88 ++ packages/agent/src/agent.ts | 9 +- packages/agent/src/utils/image-parts.ts | 144 ++++ .../src/utils/read-media-injector.test.ts | 41 +- .../agent/src/utils/read-media-injector.ts | 18 +- 42 files changed, 4825 insertions(+), 64 deletions(-) create mode 100644 internal/channel/adapters/qq/client.go create mode 100644 internal/channel/adapters/qq/client_test.go create mode 100644 internal/channel/adapters/qq/config.go create mode 100644 internal/channel/adapters/qq/config_test.go create mode 100644 internal/channel/adapters/qq/descriptor.go create mode 100644 internal/channel/adapters/qq/face_tags.go create mode 100644 internal/channel/adapters/qq/factory.go create mode 100644 internal/channel/adapters/qq/qq.go create mode 100644 internal/channel/adapters/qq/receive.go create mode 100644 internal/channel/adapters/qq/receive_test.go create mode 100644 internal/channel/adapters/qq/send.go create mode 100644 internal/channel/adapters/qq/send_test.go create mode 100644 internal/channel/adapters/qq/stream.go create mode 100644 internal/channel/adapters/qq/stream_test.go create mode 100644 internal/channel/adapters/qq/target_resolver.go create mode 100644 internal/channel/adapters/qq/target_resolver_test.go create mode 100644 internal/conversation/flow/resolver_dedupe_test.go create mode 100644 internal/mcp/mcpclient/client_test.go create mode 100644 packages/agent/src/agent.test.ts create mode 100644 packages/agent/src/utils/image-parts.ts diff --git a/apps/web/src/i18n/locales/en.json b/apps/web/src/i18n/locales/en.json index c01894d3..c5f7815b 100644 --- a/apps/web/src/i18n/locales/en.json +++ b/apps/web/src/i18n/locales/en.json @@ -668,6 +668,7 @@ "types": { "feishu": "Feishu", "discord": "Discord", + "qq": "QQ", "telegram": "Telegram", "web": "Web", "local": "Local" @@ -675,6 +676,7 @@ "typesShort": { "feishu": "FS", "discord": "DC", + "qq": "QQ", "telegram": "TG", "web": "Web", "local": "CLI" diff --git a/apps/web/src/i18n/locales/zh.json b/apps/web/src/i18n/locales/zh.json index 37f49b6a..e2ae97d4 100644 --- a/apps/web/src/i18n/locales/zh.json +++ b/apps/web/src/i18n/locales/zh.json @@ -664,6 +664,7 @@ "types": { "feishu": "飞书", "discord": "Discord", + "qq": "QQ", "telegram": "Telegram", "web": "Web", "local": "本地" @@ -671,6 +672,7 @@ "typesShort": { "feishu": "飞", "discord": "DC", + "qq": "QQ", "telegram": "TG", "web": "Web", "local": "CLI" diff --git a/apps/web/src/pages/bots/components/bot-channels.vue b/apps/web/src/pages/bots/components/bot-channels.vue index 878252d2..5bc6b190 100644 --- a/apps/web/src/pages/bots/components/bot-channels.vue +++ b/apps/web/src/pages/bots/components/bot-channels.vue @@ -212,6 +212,7 @@ function addChannel(type: string) { function channelIcon(type: string): string { const icons: Record = { + qq: 'QQ', telegram: 'TG', feishu: '飞', } @@ -220,6 +221,7 @@ function channelIcon(type: string): string { function channelBadgeClass(type: string): string { const classes: Record = { + qq: 'bg-sky-100 text-sky-700 dark:bg-sky-900 dark:text-sky-300', telegram: 'bg-blue-100 text-blue-700 dark:bg-blue-900 dark:text-blue-300', feishu: 'bg-indigo-100 text-indigo-700 dark:bg-indigo-900 dark:text-indigo-300', } diff --git a/apps/web/src/pages/settings/index.vue b/apps/web/src/pages/settings/index.vue index c0aa4ea3..fe1431b3 100644 --- a/apps/web/src/pages/settings/index.vue +++ b/apps/web/src/pages/settings/index.vue @@ -303,7 +303,7 @@ function platformLabel(platformKey: string): string { } const platformOptions = computed(() => { - const options = new Set(['telegram', 'feishu', 'discord']) + const options = new Set(['telegram', 'feishu', 'discord', 'qq']) for (const identity of identities.value) { const platform = identity.channel.trim() if (platform) { diff --git a/apps/web/src/utils/channel-icons.ts b/apps/web/src/utils/channel-icons.ts index bf90cfbc..bef0cc9d 100644 --- a/apps/web/src/utils/channel-icons.ts +++ b/apps/web/src/utils/channel-icons.ts @@ -10,6 +10,7 @@ const LOCAL_CHANNEL_IMAGES: Record = { } const CHANNEL_ICONS: Record = { + qq: ['fab', 'qq'], telegram: ['fab', 'telegram'], feishu: ['fas', 'comment-dots'], web: ['fas', 'globe'], diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 7cfc6cd6..58695b54 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -28,6 +28,7 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/discord" "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" + "github.com/memohai/memoh/internal/channel/adapters/qq" "github.com/memohai/memoh/internal/channel/adapters/telegram" "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/channel/inbound" @@ -391,6 +392,10 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService discordAdapter.SetAssetOpener(mediaService) registry.MustRegister(discordAdapter) + qqAdapter := qq.NewQQAdapter(log) + qqAdapter.SetAssetOpener(mediaService) + registry.MustRegister(qqAdapter) + registry.MustRegister(feishu.NewFeishuAdapter(log)) registry.MustRegister(local.NewCLIAdapter(hub)) registry.MustRegister(local.NewWebAdapter(hub)) @@ -413,6 +418,17 @@ func provideChannelRouter( inboxService *inbox.Service, rc *boot.RuntimeConfig, ) *inbound.ChannelInboundProcessor { + adapter, ok := registry.Get(qq.Type) + if !ok { + panic("qq adapter not registered") + } + qqAdapter, ok := adapter.(*qq.QQAdapter) + if !ok { + panic("qq adapter has unexpected type") + } + qqAdapter.SetChannelIdentityResolver(identityService) + qqAdapter.SetRouteResolver(routeService) + processor := inbound.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, botService, policyService, preauthService, bindService, rc.JwtSecret, 5*time.Minute) processor.SetMediaService(mediaService) processor.SetStreamObserver(local.NewRouteHubBroadcaster(hub)) diff --git a/internal/attachment/normalize.go b/internal/attachment/normalize.go index d20e822b..ebf28731 100644 --- a/internal/attachment/normalize.go +++ b/internal/attachment/normalize.go @@ -35,6 +35,9 @@ func NormalizeMime(raw string) string { if idx := strings.Index(mime, ";"); idx >= 0 { mime = strings.TrimSpace(mime[:idx]) } + if !strings.Contains(mime, "/") { + return "" + } return mime } diff --git a/internal/attachment/normalize_test.go b/internal/attachment/normalize_test.go index 084c856d..b91c37e2 100644 --- a/internal/attachment/normalize_test.go +++ b/internal/attachment/normalize_test.go @@ -48,6 +48,9 @@ func TestNormalizeMime(t *testing.T) { if got != "image/jpeg" { t.Fatalf("NormalizeMime unexpected result: %q", got) } + if got := NormalizeMime("file"); got != "" { + t.Fatalf("NormalizeMime should drop invalid mime token, got %q", got) + } } func TestMimeFromDataURL(t *testing.T) { @@ -67,6 +70,9 @@ func TestResolveMime(t *testing.T) { if got := ResolveMime(media.MediaTypeFile, "application/octet-stream", "application/pdf"); got != "application/pdf" { t.Fatalf("ResolveMime file unexpected result: %q", got) } + if got := ResolveMime(media.MediaTypeFile, "file", "text/plain"); got != "text/plain" { + t.Fatalf("ResolveMime should prefer sniffed mime for invalid source token, got %q", got) + } if got := ResolveMime(media.MediaTypeImage, "", ""); got != "application/octet-stream" { t.Fatalf("ResolveMime empty unexpected result: %q", got) } diff --git a/internal/channel/adapters/qq/client.go b/internal/channel/adapters/qq/client.go new file mode 100644 index 00000000..d61d358d --- /dev/null +++ b/internal/channel/adapters/qq/client.go @@ -0,0 +1,397 @@ +package qq + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" +) + +type qqClient struct { + appID string + clientSecret string + httpClient *http.Client + logger interface { + Debug(string, ...any) + } + apiBaseURL string + tokenURL string + + tokenMu sync.Mutex + token string + expiresAt time.Time + + msgSeqMu sync.Mutex + msgSeq map[string]int +} + +func (c *qqClient) matches(cfg Config) bool { + return c.appID == cfg.AppID && c.clientSecret == cfg.AppSecret +} + +func (c *qqClient) clearToken() { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + c.token = "" + c.expiresAt = time.Time{} +} + +func (c *qqClient) accessToken(ctx context.Context) (string, error) { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + + if c.token != "" && time.Now().Before(c.expiresAt.Add(-5*time.Minute)) { + return c.token, nil + } + + payload := map[string]string{ + "appId": c.appID, + "clientSecret": c.clientSecret, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + u, err := url.Parse(c.tokenURL) + if err != nil || (u.Scheme != "https" && !isLocalhost(u.Host)) { + return "", fmt.Errorf("invalid token url: %s", c.tokenURL) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.tokenURL, bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) //nolint:gosec // token URL is validated to https or localhost above + if err != nil { + return "", fmt.Errorf("qq token request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("qq token read: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("qq token request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(raw))) + } + + var result map[string]json.RawMessage + if err := json.Unmarshal(raw, &result); err != nil { + return "", fmt.Errorf("qq token decode: %w", err) + } + tokenBytes, ok := result["access_token"] + if !ok { + return "", errors.New("qq token response missing access_token") + } + var token string + if err := json.Unmarshal(tokenBytes, &token); err != nil { + return "", fmt.Errorf("qq token decode: %w", err) + } + if strings.TrimSpace(token) == "" { + return "", errors.New("qq token response missing access_token") + } + expiresIn := parseQQExpiresIn(result["expires_in"]) + if expiresIn <= 0 { + expiresIn = 7200 + } + c.token = token + c.expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second) + return c.token, nil +} + +func (c *qqClient) gatewayURL(ctx context.Context) (string, error) { + var result struct { + URL string `json:"url"` + } + if err := c.doJSON(ctx, http.MethodGet, "/gateway", nil, &result); err != nil { + return "", err + } + if strings.TrimSpace(result.URL) == "" { + return "", errors.New("qq gateway response missing url") + } + return result.URL, nil +} + +func (c *qqClient) nextMsgSeq(replyTo string) int { + if strings.TrimSpace(replyTo) == "" { + return 1 + } + c.msgSeqMu.Lock() + defer c.msgSeqMu.Unlock() + + next := c.msgSeq[replyTo] + 1 + c.msgSeq[replyTo] = next + if len(c.msgSeq) > 1024 { + for key := range c.msgSeq { + delete(c.msgSeq, key) + if len(c.msgSeq) <= 512 { + break + } + } + } + return next +} + +func parseQQExpiresIn(raw json.RawMessage) int { + trimmed := strings.TrimSpace(string(raw)) + if trimmed == "" || trimmed == "null" { + return 0 + } + + var numeric int + if err := json.Unmarshal(raw, &numeric); err == nil { + return numeric + } + + var text string + if err := json.Unmarshal(raw, &text); err == nil { + value, err := strconv.Atoi(strings.TrimSpace(text)) + if err == nil { + return value + } + } + + return 0 +} + +func (c *qqClient) doJSON(ctx context.Context, method, path string, payload any, out any) error { + return c.doJSONWithRetry(ctx, method, c.apiBaseURL+path, payload, out, true) +} + +func (c *qqClient) doJSONWithRetry(ctx context.Context, method, url string, payload any, out any, auth bool) error { + var lastErr error + for attempt := 0; attempt < 2; attempt++ { + lastErr = c.doJSONOnce(ctx, method, url, payload, out, auth) + if lastErr == nil { + return nil + } + if !auth || !strings.Contains(lastErr.Error(), "status=401") { + return lastErr + } + c.clearToken() + } + return lastErr +} + +func (c *qqClient) doJSONOnce(ctx context.Context, method, requestURL string, payload any, out any, auth bool) error { + var body io.Reader + if payload != nil { + encoded, err := json.Marshal(payload) + if err != nil { + return err + } + body = bytes.NewReader(encoded) + } + + u, err := url.Parse(requestURL) + if err != nil || (u.Scheme != "https" && !isLocalhost(u.Host)) { + return fmt.Errorf("invalid api url: %s", requestURL) + } + req, err := http.NewRequestWithContext(ctx, method, requestURL, body) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + if auth { + token, err := c.accessToken(ctx) + if err != nil { + return err + } + req.Header.Set("Authorization", "QQBot "+token) + } + + resp, err := c.httpClient.Do(req) //nolint:gosec // requestURL is validated to https or localhost above + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf( + "qq api request failed: method=%s url=%s status=%d body=%s", + method, + requestURL, + resp.StatusCode, + strings.TrimSpace(string(raw)), + ) + } + if out == nil || len(raw) == 0 { + return nil + } + if err := json.Unmarshal(raw, out); err != nil { + return fmt.Errorf("qq api decode failed: %w", err) + } + return nil +} + +type qqMessageResponse struct { + ID string `json:"id"` + Timestamp any `json:"timestamp"` +} + +type qqUploadResponse struct { + FileUUID string `json:"file_uuid"` + FileInfo string `json:"file_info"` + TTL int `json:"ttl"` +} + +func (c *qqClient) sendText(ctx context.Context, target qqTarget, text string, replyTo string, markdown bool) error { + text = strings.TrimSpace(text) + if text == "" { + return nil + } + + switch target.Kind { + case qqTargetC2C: + if replyTo == "" { + return c.sendProactive(ctx, "/v2/users/"+target.ID+"/messages", text, markdown) + } + body := buildReplyTextBody(text, replyTo, c.nextMsgSeq(replyTo), markdown) + return c.doJSON(ctx, http.MethodPost, "/v2/users/"+target.ID+"/messages", body, &qqMessageResponse{}) + case qqTargetGroup: + if replyTo == "" { + return c.sendProactive(ctx, "/v2/groups/"+target.ID+"/messages", text, markdown) + } + body := buildReplyTextBody(text, replyTo, c.nextMsgSeq(replyTo), markdown) + return c.doJSON(ctx, http.MethodPost, "/v2/groups/"+target.ID+"/messages", body, &qqMessageResponse{}) + case qqTargetChannel: + body := map[string]any{"content": text} + if strings.TrimSpace(replyTo) != "" { + replyID := strings.TrimSpace(replyTo) + body["msg_id"] = replyID + body["message_reference"] = map[string]any{"message_id": replyID} + } + return c.doJSON(ctx, http.MethodPost, "/channels/"+target.ID+"/messages", body, &qqMessageResponse{}) + default: + return fmt.Errorf("unsupported qq target kind: %s", target.Kind) + } +} + +func (c *qqClient) sendProactive(ctx context.Context, path, text string, markdown bool) error { + body := map[string]any{} + if markdown { + body["markdown"] = map[string]any{"content": text} + body["msg_type"] = 2 + } else { + body["content"] = text + body["msg_type"] = 0 + } + return c.doJSON(ctx, http.MethodPost, path, body, &qqMessageResponse{}) +} + +func buildReplyTextBody(text, replyTo string, seq int, markdown bool) map[string]any { + body := map[string]any{ + "msg_id": strings.TrimSpace(replyTo), + "msg_seq": seq, + } + if markdown { + body["markdown"] = map[string]any{"content": text} + body["msg_type"] = 2 + } else { + body["content"] = text + body["msg_type"] = 0 + } + return body +} + +func (c *qqClient) sendInputHint(ctx context.Context, openID, replyTo string) error { + if strings.TrimSpace(openID) == "" || strings.TrimSpace(replyTo) == "" { + return nil + } + body := map[string]any{ + "msg_type": 6, + "input_notify": map[string]any{ + "input_type": 1, + "input_second": 60, + }, + "msg_seq": c.nextMsgSeq(replyTo), + "msg_id": strings.TrimSpace(replyTo), + } + return c.doJSON(ctx, http.MethodPost, "/v2/users/"+openID+"/messages", body, nil) +} + +func (c *qqClient) uploadMedia(ctx context.Context, target qqTarget, fileType int, rawBase64, fileName string) (string, error) { + rawBase64 = strings.TrimSpace(rawBase64) + if rawBase64 == "" { + return "", errors.New("qq upload requires file_data") + } + body := map[string]any{ + "file_type": fileType, + "srv_send_msg": false, + } + body["file_data"] = rawBase64 + if fileType == qqMediaTypeFile && strings.TrimSpace(fileName) != "" { + body["file_name"] = strings.TrimSpace(fileName) + } + + var path string + switch target.Kind { + case qqTargetC2C: + path = "/v2/users/" + target.ID + "/files" + case qqTargetGroup: + path = "/v2/groups/" + target.ID + "/files" + default: + return "", fmt.Errorf("qq upload not supported for target kind: %s", target.Kind) + } + + var result qqUploadResponse + if err := c.doJSON(ctx, http.MethodPost, path, body, &result); err != nil { + return "", err + } + if strings.TrimSpace(result.FileInfo) == "" { + return "", errors.New("qq upload response missing file_info") + } + return result.FileInfo, nil +} + +func (c *qqClient) sendMedia(ctx context.Context, target qqTarget, fileInfo, replyTo, content string) error { + body := map[string]any{ + "msg_type": 7, + "media": map[string]any{ + "file_info": fileInfo, + }, + } + if strings.TrimSpace(content) != "" { + body["content"] = strings.TrimSpace(content) + } + if strings.TrimSpace(replyTo) != "" { + body["msg_id"] = strings.TrimSpace(replyTo) + body["msg_seq"] = c.nextMsgSeq(replyTo) + } else { + body["msg_seq"] = 1 + } + + switch target.Kind { + case qqTargetC2C: + return c.doJSON(ctx, http.MethodPost, "/v2/users/"+target.ID+"/messages", body, &qqMessageResponse{}) + case qqTargetGroup: + return c.doJSON(ctx, http.MethodPost, "/v2/groups/"+target.ID+"/messages", body, &qqMessageResponse{}) + default: + return fmt.Errorf("qq media send not supported for target kind: %s", target.Kind) + } +} + +func isLocalhost(host string) bool { + host = strings.ToLower(host) + if host == "localhost" || host == "127.0.0.1" || host == "::1" { + return true + } + if strings.HasPrefix(host, "127.0.0.1:") || strings.HasPrefix(host, "[::1]:") || strings.HasPrefix(host, "localhost:") { + return true + } + return false +} diff --git a/internal/channel/adapters/qq/client_test.go b/internal/channel/adapters/qq/client_test.go new file mode 100644 index 00000000..b8fcda38 --- /dev/null +++ b/internal/channel/adapters/qq/client_test.go @@ -0,0 +1,45 @@ +package qq + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestQQAccessTokenAcceptsStringExpiresIn(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/app/getAppAccessToken" { + http.NotFound(w, r) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": "7200", + }) + })) + defer server.Close() + + client := &qqClient{ + appID: "1024", + clientSecret: "secret", + httpClient: server.Client(), + tokenURL: server.URL + "/app/getAppAccessToken", + msgSeq: make(map[string]int), + } + + token, err := client.accessToken(context.Background()) + if err != nil { + t.Fatalf("access token: %v", err) + } + if token != "token-1" { + t.Fatalf("unexpected token: %q", token) + } + if remaining := time.Until(client.expiresAt); remaining < 7100*time.Second || remaining > 7200*time.Second { + t.Fatalf("unexpected token ttl: %s", remaining) + } +} diff --git a/internal/channel/adapters/qq/config.go b/internal/channel/adapters/qq/config.go new file mode 100644 index 00000000..f6d87849 --- /dev/null +++ b/internal/channel/adapters/qq/config.go @@ -0,0 +1,171 @@ +package qq + +import ( + "errors" + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +type Config struct { + AppID string + AppSecret string + MarkdownSupport bool + EnableInputHint bool +} + +type UserConfig struct { + TargetType string + TargetID string +} + +func normalizeConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseConfig(raw) + if err != nil { + return nil, err + } + return map[string]any{ + "appId": cfg.AppID, + "clientSecret": cfg.AppSecret, + "markdownSupport": cfg.MarkdownSupport, + "enableInputHint": cfg.EnableInputHint, + }, nil +} + +func normalizeUserConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return nil, err + } + return map[string]any{ + "target_type": cfg.TargetType, + "target_id": cfg.TargetID, + }, nil +} + +func resolveTarget(raw map[string]any) (string, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return "", err + } + return cfg.TargetType + ":" + cfg.TargetID, nil +} + +func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { + cfg, err := parseUserConfig(raw) + if err != nil { + return false + } + subjectID := strings.TrimSpace(criteria.SubjectID) + if cfg.TargetType == "c2c" && subjectID != "" && subjectID == cfg.TargetID { + return true + } + if cfg.TargetType == "c2c" && strings.TrimSpace(criteria.Attribute("user_openid")) == cfg.TargetID { + return true + } + if cfg.TargetType == "group" && strings.TrimSpace(criteria.Attribute("group_openid")) == cfg.TargetID { + return true + } + if cfg.TargetType == "channel" && strings.TrimSpace(criteria.Attribute("channel_id")) == cfg.TargetID { + return true + } + return false +} + +func buildUserConfig(identity channel.Identity) map[string]any { + targetID := strings.TrimSpace(identity.Attribute("user_openid")) + if targetID == "" { + targetID = strings.TrimSpace(identity.SubjectID) + } + if targetID == "" { + return map[string]any{} + } + return map[string]any{ + "target_type": "c2c", + "target_id": targetID, + } +} + +func parseConfig(raw map[string]any) (Config, error) { + appID := strings.TrimSpace(channel.ReadString(raw, "appId", "app_id")) + clientSecret := strings.TrimSpace(channel.ReadString(raw, "clientSecret", "client_secret")) + if appID == "" { + return Config{}, errors.New("qq appId is required") + } + if clientSecret == "" { + return Config{}, errors.New("qq clientSecret is required") + } + return Config{ + AppID: appID, + AppSecret: clientSecret, + MarkdownSupport: readBool(raw, true, "markdownSupport", "markdown_support"), + EnableInputHint: readBool(raw, true, "enableInputHint", "enable_input_hint"), + }, nil +} + +func parseUserConfig(raw map[string]any) (UserConfig, error) { + targetType := strings.ToLower(strings.TrimSpace(channel.ReadString(raw, "targetType", "target_type"))) + targetID := strings.TrimSpace(channel.ReadString(raw, "targetId", "target_id")) + if targetType == "" || targetID == "" { + switch { + case strings.TrimSpace(channel.ReadString(raw, "userOpenid", "user_openid")) != "": + targetType = "c2c" + targetID = strings.TrimSpace(channel.ReadString(raw, "userOpenid", "user_openid")) + case strings.TrimSpace(channel.ReadString(raw, "groupOpenid", "group_openid")) != "": + targetType = "group" + targetID = strings.TrimSpace(channel.ReadString(raw, "groupOpenid", "group_openid")) + case strings.TrimSpace(channel.ReadString(raw, "channelId", "channel_id")) != "": + targetType = "channel" + targetID = strings.TrimSpace(channel.ReadString(raw, "channelId", "channel_id")) + } + } + if targetType == "" || targetID == "" { + return UserConfig{}, errors.New("qq user config requires target_type and target_id") + } + switch targetType { + case "c2c", "group", "channel": + default: + return UserConfig{}, errors.New("qq target_type must be c2c, group, or channel") + } + return UserConfig{TargetType: targetType, TargetID: targetID}, nil +} + +func normalizeTarget(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + for _, prefix := range []string{"qq:", "qqbot:"} { + if strings.HasPrefix(strings.ToLower(value), prefix) { + value = strings.TrimSpace(value[len(prefix):]) + break + } + } + for _, targetType := range []string{"c2c:", "group:", "channel:"} { + if strings.HasPrefix(strings.ToLower(value), targetType) { + return strings.ToLower(targetType[:len(targetType)-1]) + ":" + strings.TrimSpace(value[len(targetType):]) + } + } + return "c2c:" + value +} + +func readBool(raw map[string]any, fallback bool, keys ...string) bool { + for _, key := range keys { + value, ok := raw[key] + if !ok { + continue + } + switch v := value.(type) { + case bool: + return v + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "true", "1", "yes", "on": + return true + case "false", "0", "no", "off": + return false + } + } + } + return fallback +} diff --git a/internal/channel/adapters/qq/config_test.go b/internal/channel/adapters/qq/config_test.go new file mode 100644 index 00000000..96fb92d9 --- /dev/null +++ b/internal/channel/adapters/qq/config_test.go @@ -0,0 +1,192 @@ +package qq + +import ( + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +func TestNormalizeConfig(t *testing.T) { + t.Parallel() + + got, err := normalizeConfig(map[string]any{ + "app_id": "1024", + "client_secret": "secret", + "markdown_support": true, + "enable_input_hint": false, + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got["appId"] != "1024" { + t.Fatalf("unexpected appId: %#v", got["appId"]) + } + if got["clientSecret"] != "secret" { + t.Fatalf("unexpected clientSecret: %#v", got["clientSecret"]) + } + if got["markdownSupport"] != true { + t.Fatalf("unexpected markdownSupport: %#v", got["markdownSupport"]) + } + if got["enableInputHint"] != false { + t.Fatalf("unexpected enableInputHint: %#v", got["enableInputHint"]) + } +} + +func TestNormalizeConfigRequiresSecrets(t *testing.T) { + t.Parallel() + + if _, err := normalizeConfig(map[string]any{ + "client_secret": "secret", + }); err == nil { + t.Fatal("expected appId validation error") + } + if _, err := normalizeConfig(map[string]any{ + "app_id": "1024", + }); err == nil { + t.Fatal("expected clientSecret validation error") + } +} + +func TestNormalizeUserConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw map[string]any + want map[string]any + }{ + { + name: "explicit target", + raw: map[string]any{ + "target_type": "group", + "target_id": "group-openid", + }, + want: map[string]any{ + "target_type": "group", + "target_id": "group-openid", + }, + }, + { + name: "user openid alias", + raw: map[string]any{ + "user_openid": "user-openid", + }, + want: map[string]any{ + "target_type": "c2c", + "target_id": "user-openid", + }, + }, + { + name: "channel id alias", + raw: map[string]any{ + "channel_id": "12345", + }, + want: map[string]any{ + "target_type": "channel", + "target_id": "12345", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := normalizeUserConfig(tt.raw) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got["target_type"] != tt.want["target_type"] { + t.Fatalf("unexpected target_type: %#v", got["target_type"]) + } + if got["target_id"] != tt.want["target_id"] { + t.Fatalf("unexpected target_id: %#v", got["target_id"]) + } + }) + } +} + +func TestNormalizeUserConfigRequiresTarget(t *testing.T) { + t.Parallel() + + if _, err := normalizeUserConfig(map[string]any{}); err == nil { + t.Fatal("expected target validation error") + } +} + +func TestResolveTarget(t *testing.T) { + t.Parallel() + + target, err := resolveTarget(map[string]any{ + "target_type": "c2c", + "target_id": "user-openid", + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if target != "c2c:user-openid" { + t.Fatalf("unexpected target: %s", target) + } +} + +func TestNormalizeTarget(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want string + }{ + {input: "qq:group:abc", want: "group:abc"}, + {input: "qqbot:c2c:USER1", want: "c2c:USER1"}, + {input: "channel:123", want: "channel:123"}, + {input: "00112233445566778899AABBCCDDEEFF", want: "c2c:00112233445566778899AABBCCDDEEFF"}, + } + + for _, tt := range tests { + if got := normalizeTarget(tt.input); got != tt.want { + t.Fatalf("normalizeTarget(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestMatchBinding(t *testing.T) { + t.Parallel() + + config := map[string]any{ + "target_type": "c2c", + "target_id": "user-openid", + } + + if !matchBinding(config, channel.BindingCriteria{ + SubjectID: "user-openid", + }) { + t.Fatal("expected subject match") + } + + if !matchBinding(config, channel.BindingCriteria{ + Attributes: map[string]string{"user_openid": "user-openid"}, + }) { + t.Fatal("expected user_openid match") + } + + if matchBinding(config, channel.BindingCriteria{ + SubjectID: "other-user", + }) { + t.Fatal("unexpected mismatch") + } +} + +func TestBuildUserConfig(t *testing.T) { + t.Parallel() + + got := buildUserConfig(channel.Identity{ + SubjectID: "user-openid", + }) + + if got["target_type"] != "c2c" { + t.Fatalf("unexpected target_type: %#v", got["target_type"]) + } + if got["target_id"] != "user-openid" { + t.Fatalf("unexpected target_id: %#v", got["target_id"]) + } +} diff --git a/internal/channel/adapters/qq/descriptor.go b/internal/channel/adapters/qq/descriptor.go new file mode 100644 index 00000000..89b276b7 --- /dev/null +++ b/internal/channel/adapters/qq/descriptor.go @@ -0,0 +1,5 @@ +package qq + +import "github.com/memohai/memoh/internal/channel" + +const Type channel.ChannelType = "qq" diff --git a/internal/channel/adapters/qq/face_tags.go b/internal/channel/adapters/qq/face_tags.go new file mode 100644 index 00000000..e94e18f6 --- /dev/null +++ b/internal/channel/adapters/qq/face_tags.go @@ -0,0 +1,31 @@ +package qq + +import ( + "encoding/base64" + "encoding/json" + "errors" + "regexp" +) + +var faceTagPattern = regexp.MustCompile(``) + +func decodeFaceTag(raw string) (string, error) { + matches := faceTagPattern.FindStringSubmatch(raw) + if len(matches) < 2 { + return "", errors.New("qq face tag ext is missing") + } + decoded, err := base64.StdEncoding.DecodeString(matches[1]) + if err != nil { + return "", err + } + var payload struct { + Text string `json:"text"` + } + if err := json.Unmarshal(decoded, &payload); err != nil { + return "", err + } + if payload.Text == "" { + return "", errors.New("qq face tag text is empty") + } + return payload.Text, nil +} diff --git a/internal/channel/adapters/qq/factory.go b/internal/channel/adapters/qq/factory.go new file mode 100644 index 00000000..3afda404 --- /dev/null +++ b/internal/channel/adapters/qq/factory.go @@ -0,0 +1,18 @@ +package qq + +import ( + "log/slog" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/channel/identities" + "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/media" +) + +func ProvideQQAdapter(log *slog.Logger, mediaService *media.Service, identityService *identities.Service, routeService *route.DBService) channel.Adapter { + adapter := NewQQAdapter(log) + adapter.SetAssetOpener(mediaService) + adapter.SetChannelIdentityResolver(identityService) + adapter.SetRouteResolver(routeService) + return adapter +} diff --git a/internal/channel/adapters/qq/qq.go b/internal/channel/adapters/qq/qq.go new file mode 100644 index 00000000..8da6d3f1 --- /dev/null +++ b/internal/channel/adapters/qq/qq.go @@ -0,0 +1,264 @@ +package qq + +import ( + "context" + "io" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/memohai/memoh/internal/channel" + identitypkg "github.com/memohai/memoh/internal/channel/identities" + routepkg "github.com/memohai/memoh/internal/channel/route" + "github.com/memohai/memoh/internal/media" +) + +const ( + defaultAPIBaseURL = "https://api.sgroup.qq.com" + qqOAuthEndpoint = "https://bots.qq.com/app/getAppAccessToken" + defaultChunkLimit = 2000 + defaultReadTimeout = 45 * time.Second + defaultWriteTimeout = 15 * time.Second +) + +type assetOpener interface { + Open(ctx context.Context, botID, contentHash string) (io.ReadCloser, media.Asset, error) +} + +type sessionState struct { + SessionID string + LastSeq int + IntentLevel int +} + +type channelIdentityResolver interface { + GetByID(ctx context.Context, channelIdentityID string) (identitypkg.ChannelIdentity, error) + ListCanonicalChannelIdentities(ctx context.Context, channelIdentityID string) ([]identitypkg.ChannelIdentity, error) + ListUserChannelIdentities(ctx context.Context, userID string) ([]identitypkg.ChannelIdentity, error) +} + +type routeResolver interface { + GetByID(ctx context.Context, routeID string) (routepkg.Route, error) +} + +type QQAdapter struct { + logger *slog.Logger + httpClient *http.Client + dialer *websocket.Dialer + apiBaseURL string + tokenURL string + + mu sync.Mutex + clients map[string]*qqClient + sessions map[string]sessionState + assets assetOpener + identity channelIdentityResolver + routes routeResolver +} + +func NewQQAdapter(log *slog.Logger) *QQAdapter { + if log == nil { + log = slog.Default() + } + return &QQAdapter{ + logger: log.With(slog.String("adapter", "qq")), + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + dialer: &websocket.Dialer{ + HandshakeTimeout: 15 * time.Second, + }, + apiBaseURL: defaultAPIBaseURL, + tokenURL: qqOAuthEndpoint, + clients: make(map[string]*qqClient), + sessions: make(map[string]sessionState), + } +} + +func (*QQAdapter) Type() channel.ChannelType { + return Type +} + +func (*QQAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: Type, + DisplayName: "QQ", + Capabilities: channel.ChannelCapabilities{ + Text: true, + Markdown: true, + Attachments: true, + Media: true, + Reply: true, + BlockStreaming: true, + ChatTypes: []string{"direct", "group", "channel"}, + }, + OutboundPolicy: channel.OutboundPolicy{ + TextChunkLimit: defaultChunkLimit, + ChunkerMode: channel.ChunkerModeMarkdown, + MediaOrder: channel.OutboundOrderTextFirst, + InlineTextWithMedia: true, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "appId": { + Type: channel.FieldString, + Required: true, + Title: "App ID", + }, + "clientSecret": { + Type: channel.FieldSecret, + Required: true, + Title: "Client Secret", + }, + "markdownSupport": { + Type: channel.FieldBool, + Title: "Markdown Support", + Description: "Enable QQ markdown message mode for C2C and group replies when the bot has permission.", + }, + "enableInputHint": { + Type: channel.FieldBool, + Title: "Input Hint", + Description: "Send QQ input-notify hints for direct messages while the bot is processing.", + }, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "target_type": { + Type: channel.FieldEnum, + Required: true, + Title: "Target Type", + Enum: []string{"c2c", "group", "channel"}, + }, + "target_id": { + Type: channel.FieldString, + Required: true, + Title: "Target ID", + }, + }, + }, + TargetSpec: channel.TargetSpec{ + Format: "c2c: | group: | channel:", + Hints: []channel.TargetHint{ + {Label: "Direct", Example: "c2c:00112233445566778899AABBCCDDEEFF"}, + {Label: "Group", Example: "group:00112233445566778899AABBCCDDEEFF"}, + {Label: "Channel", Example: "channel:1234567890"}, + }, + }, + } +} + +func (*QQAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { + return normalizeConfig(raw) +} + +func (*QQAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + return normalizeUserConfig(raw) +} + +func (*QQAdapter) NormalizeTarget(raw string) string { + return normalizeTarget(raw) +} + +func (*QQAdapter) ResolveTarget(userConfig map[string]any) (string, error) { + return resolveTarget(userConfig) +} + +func (*QQAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { + return matchBinding(config, criteria) +} + +func (*QQAdapter) BuildUserConfig(identity channel.Identity) map[string]any { + return buildUserConfig(identity) +} + +func (a *QQAdapter) SetAssetOpener(opener assetOpener) { + a.mu.Lock() + defer a.mu.Unlock() + a.assets = opener +} + +func (a *QQAdapter) SetChannelIdentityResolver(resolver channelIdentityResolver) { + a.mu.Lock() + defer a.mu.Unlock() + a.identity = resolver +} + +func (a *QQAdapter) SetRouteResolver(resolver routeResolver) { + a.mu.Lock() + defer a.mu.Unlock() + a.routes = resolver +} + +func (a *QQAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { + parsed, err := parseConfig(cfg.Credentials) + if err != nil { + return channel.ProcessingStatusHandle{}, err + } + if !parsed.EnableInputHint || strings.TrimSpace(info.SourceMessageID) == "" { + return channel.ProcessingStatusHandle{}, nil + } + target, err := parseTarget(info.ReplyTarget) + if err != nil || target.Kind != qqTargetC2C { + return channel.ProcessingStatusHandle{}, nil + } + client := a.getOrCreateClient(cfg, parsed) + if err := client.sendInputHint(ctx, target.ID, info.SourceMessageID); err != nil { + return channel.ProcessingStatusHandle{}, err + } + return channel.ProcessingStatusHandle{}, nil +} + +func (*QQAdapter) ProcessingCompleted(context.Context, channel.ChannelConfig, channel.InboundMessage, channel.ProcessingStatusInfo, channel.ProcessingStatusHandle) error { + return nil +} + +func (*QQAdapter) ProcessingFailed(context.Context, channel.ChannelConfig, channel.InboundMessage, channel.ProcessingStatusInfo, channel.ProcessingStatusHandle, error) error { + return nil +} + +func (a *QQAdapter) getOrCreateClient(cfg channel.ChannelConfig, parsed Config) *qqClient { + a.mu.Lock() + defer a.mu.Unlock() + + existing, ok := a.clients[cfg.ID] + if ok && existing.matches(parsed) { + return existing + } + + client := &qqClient{ + appID: parsed.AppID, + clientSecret: parsed.AppSecret, + httpClient: a.httpClient, + logger: a.logger, + apiBaseURL: a.apiBaseURL, + tokenURL: a.tokenURL, + msgSeq: make(map[string]int), + } + a.clients[cfg.ID] = client + return client +} + +func (a *QQAdapter) loadSession(configID string) sessionState { + a.mu.Lock() + defer a.mu.Unlock() + return a.sessions[configID] +} + +func (a *QQAdapter) saveSession(configID string, state sessionState) { + a.mu.Lock() + defer a.mu.Unlock() + a.sessions[configID] = state +} + +func (a *QQAdapter) clearSession(configID string) { + a.mu.Lock() + defer a.mu.Unlock() + delete(a.sessions, configID) +} diff --git a/internal/channel/adapters/qq/receive.go b/internal/channel/adapters/qq/receive.go new file mode 100644 index 00000000..68b707ab --- /dev/null +++ b/internal/channel/adapters/qq/receive.go @@ -0,0 +1,628 @@ +package qq + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/memohai/memoh/internal/channel" +) + +const ( + qqIntentGuilds = 1 << 0 + qqIntentGuildMembers = 1 << 1 + qqIntentPublicGuildMessages = 1 << 30 + qqIntentGroupAndC2C = 1 << 25 +) + +var qqIntentLevels = []int{ + qqIntentPublicGuildMessages | qqIntentGroupAndC2C, + qqIntentPublicGuildMessages | qqIntentGuildMembers, +} + +type MessageAttachment struct { + ContentType string `json:"content_type"` + FileName string `json:"filename,omitempty"` + Height int `json:"height,omitempty"` + Width int `json:"width,omitempty"` + Size int64 `json:"size,omitempty"` + URL string `json:"url"` + VoiceWavURL string `json:"voice_wav_url,omitempty"` +} + +type C2CAuthor struct { + ID string `json:"id,omitempty"` + UnionOpenID string `json:"union_openid,omitempty"` + UserOpenID string `json:"user_openid"` +} + +type GroupAuthor struct { + ID string `json:"id,omitempty"` + MemberOpenID string `json:"member_openid"` +} + +type GuildAuthor struct { + ID string `json:"id"` + Username string `json:"username,omitempty"` + Bot bool `json:"bot,omitempty"` +} + +type C2CMessageEvent struct { + Author C2CAuthor `json:"author"` + Content string `json:"content"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + Attachments []MessageAttachment `json:"attachments,omitempty"` +} + +type GroupMessageEvent struct { + Author GroupAuthor `json:"author"` + Content string `json:"content"` + ID string `json:"id"` + Timestamp string `json:"timestamp"` + GroupID string `json:"group_id,omitempty"` + GroupOpenID string `json:"group_openid"` + Attachments []MessageAttachment `json:"attachments,omitempty"` +} + +type GuildMessageEvent struct { + ID string `json:"id"` + ChannelID string `json:"channel_id"` + GuildID string `json:"guild_id,omitempty"` + Content string `json:"content"` + Timestamp string `json:"timestamp"` + Author GuildAuthor `json:"author"` + Attachments []MessageAttachment `json:"attachments,omitempty"` +} + +type wsPayload struct { + Op int `json:"op"` + D json.RawMessage `json:"d,omitempty"` + S int `json:"s,omitempty"` + T string `json:"t,omitempty"` +} + +type InboundEvent struct { + Type string + C2CMessage *C2CMessageEvent + GroupMessage *GroupMessageEvent + GuildMessage *GuildMessageEvent +} + +type gatewayWriter struct { + conn *websocket.Conn + mu sync.Mutex +} + +type heartbeatHandle struct { + cancel context.CancelFunc + done <-chan struct{} +} + +func (w *gatewayWriter) WriteJSON(v any) error { + w.mu.Lock() + defer w.mu.Unlock() + _ = w.conn.SetWriteDeadline(time.Now().Add(defaultWriteTimeout)) + return w.conn.WriteJSON(v) +} + +func (a *QQAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { + parsed, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + + connCtx, cancel := context.WithCancel(ctx) + go a.runReceiver(connCtx, cfg, parsed, handler) + + return channel.NewConnection(cfg, func(context.Context) error { + cancel() + return nil + }), nil +} + +func (a *QQAdapter) runReceiver(ctx context.Context, cfg channel.ChannelConfig, parsed Config, handler channel.InboundHandler) { + backoffs := []time.Duration{time.Second, 2 * time.Second, 5 * time.Second, 10 * time.Second, 30 * time.Second} + attempt := 0 + for ctx.Err() == nil { + healthySession, err := a.serveConnection(ctx, cfg, parsed, handler) + if err == nil || ctx.Err() != nil { + return + } + if a.logger != nil { + a.logger.Warn("qq receiver reconnect", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + delay, nextAttempt := nextReconnectDelay(backoffs, attempt, healthySession) + attempt = nextAttempt + if !sleepContext(ctx, delay) { + return + } + } +} + +func (a *QQAdapter) serveConnection(ctx context.Context, cfg channel.ChannelConfig, parsed Config, handler channel.InboundHandler) (bool, error) { + client := a.getOrCreateClient(cfg, parsed) + gatewayURL, err := client.gatewayURL(ctx) + if err != nil { + return false, err + } + + conn, resp, err := a.dialer.DialContext(ctx, gatewayURL, nil) + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + if err != nil { + return false, err + } + defer func() { _ = conn.Close() }() + _ = conn.SetReadDeadline(time.Now().Add(defaultReadTimeout)) + writer := &gatewayWriter{conn: conn} + + session := a.loadSession(cfg.ID) + var heartbeatSeq atomic.Int64 + heartbeatSeq.Store(int64(session.LastSeq)) + healthySession := false + + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = conn.Close() + case <-done: + } + }() + defer close(done) + + var heartbeat heartbeatHandle + defer func() { + if heartbeat.cancel != nil { + heartbeat.cancel() + } + }() + + for { + _, data, err := conn.ReadMessage() + if err != nil { + var closeErr *websocket.CloseError + if errors.As(err, &closeErr) { + return a.handleGatewayClose(cfg.ID, client, &session, closeErr, healthySession) + } + return healthySession, err + } + _ = conn.SetReadDeadline(time.Now().Add(defaultReadTimeout)) + + var payload wsPayload + if err := json.Unmarshal(data, &payload); err != nil { + return healthySession, fmt.Errorf("qq websocket payload decode: %w", err) + } + if payload.S > 0 { + session.LastSeq = payload.S + a.saveSession(cfg.ID, session) + heartbeatSeq.Store(int64(payload.S)) + } + + switch payload.Op { + case 10: + if err := handleHello(ctx, writer, client, &session, payload.D); err != nil { + return healthySession, err + } + if heartbeat.cancel != nil { + heartbeat.cancel() + } + interval := parseHeartbeatInterval(payload.D) + heartbeat = startHeartbeat(ctx, writer, interval, func() int { + return int(heartbeatSeq.Load()) + }) + case 0: + dispatchHealthy, err := a.handleDispatch(ctx, cfg, handler, payload.T, payload.D, &session) + if err != nil { + return healthySession, err + } + healthySession = healthySession || dispatchHealthy + case 7: + return healthySession, errors.New("qq gateway requested reconnect") + case 9: + a.adjustSessionAfterInvalid(cfg.ID, &session) + return healthySession, errors.New("qq invalid session") + case 11: + continue + } + } +} + +func (a *QQAdapter) handleGatewayClose(configID string, client *qqClient, session *sessionState, closeErr *websocket.CloseError, healthySession bool) (bool, error) { + switch closeErr.Code { + case 4004: + if client != nil { + client.clearToken() + } + case 4006, 4007, 4009: + a.clearSession(configID) + case 4914, 4915: + a.adjustSessionAfterIntentClose(configID, session) + return healthySession, fmt.Errorf("qq gateway closed with intent code %d", closeErr.Code) + } + return healthySession, closeErr +} + +func handleHello(ctx context.Context, writer *gatewayWriter, client *qqClient, session *sessionState, _ json.RawMessage) error { + token, err := client.accessToken(ctx) + if err != nil { + return err + } + if session.SessionID != "" && session.LastSeq > 0 { + payload := map[string]any{ + "op": 6, + "d": map[string]any{ + "token": "QQBot " + token, + "session_id": session.SessionID, + "seq": session.LastSeq, + }, + } + return writer.WriteJSON(payload) + } + intentLevel := session.IntentLevel + if intentLevel < 0 || intentLevel >= len(qqIntentLevels) { + intentLevel = 0 + } + session.IntentLevel = intentLevel + return writer.WriteJSON(map[string]any{ + "op": 2, + "d": map[string]any{ + "token": "QQBot " + token, + "intents": qqIntentLevels[intentLevel], + "shard": []int{0, 1}, + }, + }) +} + +func (a *QQAdapter) handleDispatch(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler, eventType string, raw json.RawMessage, session *sessionState) (bool, error) { + switch eventType { + case "READY": + var ready struct { + SessionID string `json:"session_id"` + } + if err := json.Unmarshal(raw, &ready); err != nil { + return false, err + } + session.SessionID = strings.TrimSpace(ready.SessionID) + a.saveSession(cfg.ID, *session) + return true, nil + case "RESUMED": + a.saveSession(cfg.ID, *session) + return true, nil + case "C2C_MESSAGE_CREATE": + var event C2CMessageEvent + if err := json.Unmarshal(raw, &event); err != nil { + return false, err + } + a.dispatchInbound(ctx, cfg, handler, InboundEvent{Type: eventType, C2CMessage: &event}) + return false, nil + case "GROUP_AT_MESSAGE_CREATE": + var event GroupMessageEvent + if err := json.Unmarshal(raw, &event); err != nil { + return false, err + } + a.dispatchInbound(ctx, cfg, handler, InboundEvent{Type: eventType, GroupMessage: &event}) + return false, nil + case "AT_MESSAGE_CREATE": + var event GuildMessageEvent + if err := json.Unmarshal(raw, &event); err != nil { + return false, err + } + a.dispatchInbound(ctx, cfg, handler, InboundEvent{Type: eventType, GuildMessage: &event}) + return false, nil + default: + return false, nil + } +} + +func (a *QQAdapter) dispatchInbound(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler, inbound InboundEvent) { + msg, ok := eventToInboundMessage(inbound, cfg.BotID) + if !ok { + return + } + go func() { + if err := handler(ctx, cfg, msg); err != nil && a.logger != nil { + a.logger.Error("qq handle inbound failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) + } + }() +} + +func startHeartbeat(parent context.Context, writer *gatewayWriter, interval time.Duration, seqValue func() int) heartbeatHandle { + ctx, cancel := context.WithCancel(parent) + ticker := time.NewTicker(interval) + done := make(chan struct{}) + go runHeartbeat(ctx, writer, ticker, seqValue, done) + return heartbeatHandle{ + cancel: cancel, + done: done, + } +} + +func runHeartbeat(ctx context.Context, writer *gatewayWriter, ticker *time.Ticker, seqValue func() int, done chan<- struct{}) { + defer ticker.Stop() + if done != nil { + defer close(done) + } + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _ = writer.WriteJSON(map[string]any{ + "op": 1, + "d": seqValue(), + }) + } + } +} + +func (a *QQAdapter) adjustSessionAfterInvalid(configID string, session *sessionState) { + session.SessionID = "" + session.LastSeq = 0 + a.saveSession(configID, *session) +} + +func (a *QQAdapter) adjustSessionAfterIntentClose(configID string, session *sessionState) { + session.SessionID = "" + session.LastSeq = 0 + if last := len(qqIntentLevels) - 1; last >= 0 { + switch { + case session.IntentLevel < 0: + session.IntentLevel = 0 + case session.IntentLevel < last: + session.IntentLevel++ + default: + session.IntentLevel = last + } + } + a.saveSession(configID, *session) +} + +func parseHeartbeatInterval(raw json.RawMessage) time.Duration { + var hello struct { + HeartbeatInterval int `json:"heartbeat_interval"` + } + if err := json.Unmarshal(raw, &hello); err != nil { + return 30 * time.Second + } + if hello.HeartbeatInterval <= 0 { + return 30 * time.Second + } + return time.Duration(hello.HeartbeatInterval) * time.Millisecond +} + +func eventToInboundMessage(event InboundEvent, botID string) (channel.InboundMessage, bool) { + switch event.Type { + case "C2C_MESSAGE_CREATE": + if event.C2CMessage == nil { + return channel.InboundMessage{}, false + } + payload := event.C2CMessage + subjectID := strings.TrimSpace(payload.Author.UserOpenID) + if subjectID == "" { + return channel.InboundMessage{}, false + } + return channel.InboundMessage{ + Channel: Type, + BotID: strings.TrimSpace(botID), + Message: channel.Message{ + ID: strings.TrimSpace(payload.ID), + Format: channel.MessageFormatPlain, + Text: parseFaceTags(strings.TrimSpace(payload.Content)), + Attachments: toInboundAttachments(payload.Attachments), + }, + ReplyTarget: "c2c:" + subjectID, + Sender: channel.Identity{ + SubjectID: subjectID, + Attributes: map[string]string{ + "user_openid": subjectID, + "union_openid": strings.TrimSpace(payload.Author.UnionOpenID), + }, + }, + Conversation: channel.Conversation{ + ID: subjectID, + Type: "direct", + }, + ReceivedAt: parseTimestamp(payload.Timestamp), + Source: "qq", + Metadata: map[string]any{ + "is_mentioned": false, + }, + }, true + case "GROUP_AT_MESSAGE_CREATE": + if event.GroupMessage == nil { + return channel.InboundMessage{}, false + } + payload := event.GroupMessage + subjectID := strings.TrimSpace(payload.Author.MemberOpenID) + groupID := strings.TrimSpace(payload.GroupOpenID) + if subjectID == "" || groupID == "" { + return channel.InboundMessage{}, false + } + return channel.InboundMessage{ + Channel: Type, + BotID: strings.TrimSpace(botID), + Message: channel.Message{ + ID: strings.TrimSpace(payload.ID), + Format: channel.MessageFormatPlain, + Text: parseFaceTags(strings.TrimSpace(payload.Content)), + Attachments: toInboundAttachments(payload.Attachments), + }, + ReplyTarget: "group:" + groupID, + Sender: channel.Identity{ + SubjectID: subjectID, + Attributes: map[string]string{ + "user_openid": subjectID, + "group_openid": groupID, + }, + }, + Conversation: channel.Conversation{ + ID: groupID, + Type: "group", + }, + ReceivedAt: parseTimestamp(payload.Timestamp), + Source: "qq", + Metadata: map[string]any{ + "is_mentioned": true, + "group_id": strings.TrimSpace(payload.GroupID), + "group_openid": groupID, + }, + }, true + case "AT_MESSAGE_CREATE": + if event.GuildMessage == nil { + return channel.InboundMessage{}, false + } + payload := event.GuildMessage + subjectID := strings.TrimSpace(payload.Author.ID) + channelID := strings.TrimSpace(payload.ChannelID) + if subjectID == "" || channelID == "" { + return channel.InboundMessage{}, false + } + return channel.InboundMessage{ + Channel: Type, + BotID: strings.TrimSpace(botID), + Message: channel.Message{ + ID: strings.TrimSpace(payload.ID), + Format: channel.MessageFormatPlain, + Text: parseFaceTags(strings.TrimSpace(payload.Content)), + Attachments: toInboundAttachments(payload.Attachments), + }, + ReplyTarget: "channel:" + channelID, + Sender: channel.Identity{ + SubjectID: subjectID, + DisplayName: strings.TrimSpace(payload.Author.Username), + Attributes: map[string]string{ + "user_id": subjectID, + "channel_id": channelID, + "guild_id": strings.TrimSpace(payload.GuildID), + }, + }, + Conversation: channel.Conversation{ + ID: channelID, + Type: "channel", + }, + ReceivedAt: parseTimestamp(payload.Timestamp), + Source: "qq", + Metadata: map[string]any{ + "is_mentioned": true, + "guild_id": strings.TrimSpace(payload.GuildID), + "channel_id": channelID, + }, + }, true + default: + return channel.InboundMessage{}, false + } +} + +func toInboundAttachments(items []MessageAttachment) []channel.Attachment { + if len(items) == 0 { + return nil + } + result := make([]channel.Attachment, 0, len(items)) + for _, item := range items { + attachmentURL := normalizeQQURL(item.URL) + attType := inferAttachmentType(item) + if attType == channel.AttachmentVoice && strings.TrimSpace(item.VoiceWavURL) != "" { + attachmentURL = normalizeQQURL(item.VoiceWavURL) + } + att := channel.NormalizeInboundChannelAttachment(channel.Attachment{ + Type: attType, + URL: attachmentURL, + Name: strings.TrimSpace(item.FileName), + Mime: strings.TrimSpace(item.ContentType), + Size: item.Size, + Width: item.Width, + Height: item.Height, + SourcePlatform: Type.String(), + Metadata: map[string]any{ + "voice_wav_url": normalizeQQURL(item.VoiceWavURL), + }, + }) + result = append(result, att) + } + return result +} + +func inferAttachmentType(att MessageAttachment) channel.AttachmentType { + contentType := strings.ToLower(strings.TrimSpace(att.ContentType)) + name := strings.ToLower(strings.TrimSpace(att.FileName)) + switch { + case strings.HasPrefix(contentType, "image/gif"), strings.HasSuffix(name, ".gif"): + return channel.AttachmentGIF + case strings.HasPrefix(contentType, "image/"): + return channel.AttachmentImage + case strings.HasPrefix(contentType, "video/"): + return channel.AttachmentVideo + case strings.HasPrefix(contentType, "audio/"), strings.TrimSpace(att.VoiceWavURL) != "": + return channel.AttachmentVoice + default: + return channel.AttachmentFile + } +} + +func normalizeQQURL(raw string) string { + value := strings.TrimSpace(raw) + if strings.HasPrefix(value, "//") { + return "https:" + value + } + return value +} + +func parseTimestamp(raw string) time.Time { + if ts, err := time.Parse(time.RFC3339, strings.TrimSpace(raw)); err == nil { + return ts.UTC() + } + return time.Now().UTC() +} + +func parseFaceTags(text string) string { + if strings.TrimSpace(text) == "" { + return "" + } + return faceTagPattern.ReplaceAllStringFunc(text, func(match string) string { + value, err := decodeFaceTag(match) + if err != nil { + return match + } + return "【表情: " + value + "】" + }) +} + +func sleepContext(ctx context.Context, delay time.Duration) bool { + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + +func nextReconnectDelay(backoffs []time.Duration, attempt int, healthySession bool) (time.Duration, int) { + if len(backoffs) == 0 { + return 0, attempt + } + if healthySession { + attempt = 0 + } + delay := backoffs[intMin(attempt, len(backoffs)-1)] + return delay, attempt + 1 +} + +func intMin(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/channel/adapters/qq/receive_test.go b/internal/channel/adapters/qq/receive_test.go new file mode 100644 index 00000000..41306539 --- /dev/null +++ b/internal/channel/adapters/qq/receive_test.go @@ -0,0 +1,330 @@ +package qq + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + + "github.com/memohai/memoh/internal/channel" +) + +func TestEventToInboundMessageC2C(t *testing.T) { + t.Parallel() + + msg, ok := eventToInboundMessage(InboundEvent{ + Type: "C2C_MESSAGE_CREATE", + C2CMessage: &C2CMessageEvent{ + ID: "msg-1", + Content: "hello", + Timestamp: "2026-03-06T12:00:00Z", + Author: C2CAuthor{ + UserOpenID: "user-openid", + }, + Attachments: []MessageAttachment{{ + ContentType: "image/png", + URL: "//cdn.qq.com/image.png", + FileName: "a.png", + Width: 120, + Height: 80, + Size: 2048, + }}, + }, + }, "bot-1") + if !ok { + t.Fatal("expected inbound message") + } + if msg.Channel != Type { + t.Fatalf("unexpected channel: %s", msg.Channel) + } + if msg.BotID != "bot-1" { + t.Fatalf("unexpected bot id: %s", msg.BotID) + } + if msg.ReplyTarget != "c2c:user-openid" { + t.Fatalf("unexpected reply target: %s", msg.ReplyTarget) + } + if msg.Conversation.Type != "direct" { + t.Fatalf("unexpected conversation type: %s", msg.Conversation.Type) + } + if msg.Sender.SubjectID != "user-openid" { + t.Fatalf("unexpected sender subject: %s", msg.Sender.SubjectID) + } + if len(msg.Message.Attachments) != 1 { + t.Fatalf("unexpected attachments: %d", len(msg.Message.Attachments)) + } + att := msg.Message.Attachments[0] + if att.Type != channel.AttachmentImage { + t.Fatalf("unexpected attachment type: %s", att.Type) + } + if att.URL != "https://cdn.qq.com/image.png" { + t.Fatalf("unexpected attachment url: %s", att.URL) + } + if mentioned, _ := msg.Metadata["is_mentioned"].(bool); mentioned { + t.Fatal("direct message should not be marked mentioned") + } +} + +func TestEventToInboundMessageGroupAt(t *testing.T) { + t.Parallel() + + msg, ok := eventToInboundMessage(InboundEvent{ + Type: "GROUP_AT_MESSAGE_CREATE", + GroupMessage: &GroupMessageEvent{ + ID: "msg-2", + Content: "@bot hi", + Timestamp: "2026-03-06T12:00:00Z", + GroupOpenID: "group-openid", + Author: GroupAuthor{ + MemberOpenID: "member-openid", + }, + }, + }, "bot-2") + if !ok { + t.Fatal("expected inbound message") + } + if msg.ReplyTarget != "group:group-openid" { + t.Fatalf("unexpected reply target: %s", msg.ReplyTarget) + } + if msg.Conversation.ID != "group-openid" { + t.Fatalf("unexpected conversation id: %s", msg.Conversation.ID) + } + if msg.Conversation.Type != "group" { + t.Fatalf("unexpected conversation type: %s", msg.Conversation.Type) + } + if msg.Sender.SubjectID != "member-openid" { + t.Fatalf("unexpected sender subject: %s", msg.Sender.SubjectID) + } + if mentioned, _ := msg.Metadata["is_mentioned"].(bool); !mentioned { + t.Fatal("group at message should be marked mentioned") + } +} + +func TestEventToInboundMessageChannelAt(t *testing.T) { + t.Parallel() + + msg, ok := eventToInboundMessage(InboundEvent{ + Type: "AT_MESSAGE_CREATE", + GuildMessage: &GuildMessageEvent{ + ID: "msg-3", + Content: "<@bot> hi", + Timestamp: "2026-03-06T12:00:00Z", + ChannelID: "channel-1", + GuildID: "guild-1", + Author: GuildAuthor{ + ID: "author-1", + Username: "alice", + }, + }, + }, "bot-3") + if !ok { + t.Fatal("expected inbound message") + } + if msg.ReplyTarget != "channel:channel-1" { + t.Fatalf("unexpected reply target: %s", msg.ReplyTarget) + } + if msg.Conversation.Type != "channel" { + t.Fatalf("unexpected conversation type: %s", msg.Conversation.Type) + } + if msg.Sender.DisplayName != "alice" { + t.Fatalf("unexpected sender display name: %s", msg.Sender.DisplayName) + } + if msg.Sender.Attribute("channel_id") != "channel-1" { + t.Fatalf("unexpected channel_id attribute: %s", msg.Sender.Attribute("channel_id")) + } + if msg.Metadata["guild_id"] != "guild-1" { + t.Fatalf("unexpected guild_id metadata: %#v", msg.Metadata["guild_id"]) + } + if mentioned, _ := msg.Metadata["is_mentioned"].(bool); !mentioned { + t.Fatal("channel at message should be marked mentioned") + } +} + +func TestEventToInboundMessageIgnoresUnsupportedType(t *testing.T) { + t.Parallel() + + if _, ok := eventToInboundMessage(InboundEvent{Type: "READY"}, "bot-1"); ok { + t.Fatal("unexpected inbound message for READY") + } +} + +func TestEventToInboundMessagePreservesGIFType(t *testing.T) { + t.Parallel() + + msg, ok := eventToInboundMessage(InboundEvent{ + Type: "C2C_MESSAGE_CREATE", + C2CMessage: &C2CMessageEvent{ + ID: "msg-gif", + Content: "gif", + Timestamp: "2026-03-06T12:00:00Z", + Author: C2CAuthor{ + UserOpenID: "user-openid", + }, + Attachments: []MessageAttachment{{ + ContentType: "image/gif", + URL: "https://cdn.qq.com/animated.gif", + FileName: "animated.gif", + }}, + }, + }, "bot-gif") + if !ok { + t.Fatal("expected inbound message") + } + if len(msg.Message.Attachments) != 1 { + t.Fatalf("unexpected attachments: %d", len(msg.Message.Attachments)) + } + if msg.Message.Attachments[0].Type != channel.AttachmentGIF { + t.Fatalf("unexpected attachment type: %s", msg.Message.Attachments[0].Type) + } +} + +func TestAdjustSessionAfterInvalidKeepsIntentLevel(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + session := sessionState{ + SessionID: "session-1", + LastSeq: 42, + IntentLevel: 0, + } + + adapter.adjustSessionAfterInvalid("cfg-1", &session) + + if session.SessionID != "" { + t.Fatalf("unexpected session id: %q", session.SessionID) + } + if session.LastSeq != 0 { + t.Fatalf("unexpected seq: %d", session.LastSeq) + } + if session.IntentLevel != 0 { + t.Fatalf("unexpected intent level: %d", session.IntentLevel) + } + + saved := adapter.loadSession("cfg-1") + if saved.IntentLevel != 0 { + t.Fatalf("unexpected saved intent level: %d", saved.IntentLevel) + } +} + +func TestStartHeartbeatCancelStopsSessionLoop(t *testing.T) { + t.Parallel() + + heartbeat := startHeartbeat(context.Background(), &gatewayWriter{}, time.Hour, func() int { return 0 }) + heartbeat.cancel() + + select { + case <-heartbeat.done: + case <-time.After(time.Second): + t.Fatal("heartbeat did not stop after session cancel") + } +} + +func TestHandleDispatchMarksHealthySessionForReadyAndResumed(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + cfg := channel.ChannelConfig{ID: "cfg-healthy", BotID: "bot-healthy"} + + session := sessionState{} + healthy, err := adapter.handleDispatch(context.Background(), cfg, func(context.Context, channel.ChannelConfig, channel.InboundMessage) error { + return nil + }, "READY", []byte(`{"session_id":"session-1"}`), &session) + if err != nil { + t.Fatalf("handle ready: %v", err) + } + if !healthy { + t.Fatal("expected READY to mark session healthy") + } + + healthy, err = adapter.handleDispatch(context.Background(), cfg, func(context.Context, channel.ChannelConfig, channel.InboundMessage) error { + return nil + }, "RESUMED", []byte(`{}`), &session) + if err != nil { + t.Fatalf("handle resumed: %v", err) + } + if !healthy { + t.Fatal("expected RESUMED to mark session healthy") + } +} + +func TestNextReconnectDelayResetsAfterHealthySession(t *testing.T) { + t.Parallel() + + backoffs := []time.Duration{time.Second, 2 * time.Second, 5 * time.Second} + delay, attempt := nextReconnectDelay(backoffs, 2, true) + + if delay != time.Second { + t.Fatalf("unexpected delay: %v", delay) + } + if attempt != 1 { + t.Fatalf("unexpected next attempt: %d", attempt) + } +} + +func TestHandleGatewayClose_IntentCodesRequireReconnect(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + session := sessionState{ + SessionID: "session-1", + LastSeq: 42, + IntentLevel: 0, + } + + healthy, err := adapter.handleGatewayClose( + "cfg-intent", + &qqClient{}, + &session, + &websocket.CloseError{Code: 4914}, + true, + ) + if !healthy { + t.Fatal("expected healthy flag to be preserved") + } + if err == nil { + t.Fatal("expected reconnect error") + } + if !strings.Contains(err.Error(), "intent code 4914") { + t.Fatalf("unexpected error: %v", err) + } + if session.SessionID != "" || session.LastSeq != 0 { + t.Fatalf("session should be reset, got id=%q seq=%d", session.SessionID, session.LastSeq) + } + if session.IntentLevel != 1 { + t.Fatalf("expected intent fallback level 1, got %d", session.IntentLevel) + } + + saved := adapter.loadSession("cfg-intent") + if saved.SessionID != "" || saved.LastSeq != 0 { + t.Fatalf("saved session should be reset, got id=%q seq=%d", saved.SessionID, saved.LastSeq) + } + if saved.IntentLevel != session.IntentLevel { + t.Fatalf("unexpected intent level: %d", saved.IntentLevel) + } +} + +func TestAdjustSessionAfterIntentCloseCapsIntentLevel(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + session := sessionState{ + SessionID: "session-2", + LastSeq: 99, + IntentLevel: len(qqIntentLevels) - 1, + } + + adapter.adjustSessionAfterIntentClose("cfg-intent-cap", &session) + + if session.SessionID != "" || session.LastSeq != 0 { + t.Fatalf("session should be reset, got id=%q seq=%d", session.SessionID, session.LastSeq) + } + if session.IntentLevel != len(qqIntentLevels)-1 { + t.Fatalf("expected capped intent level %d, got %d", len(qqIntentLevels)-1, session.IntentLevel) + } + + saved := adapter.loadSession("cfg-intent-cap") + if saved.IntentLevel != len(qqIntentLevels)-1 { + t.Fatalf("unexpected saved intent level: %d", saved.IntentLevel) + } +} diff --git a/internal/channel/adapters/qq/send.go b/internal/channel/adapters/qq/send.go new file mode 100644 index 00000000..c6d19526 --- /dev/null +++ b/internal/channel/adapters/qq/send.go @@ -0,0 +1,358 @@ +package qq + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/url" + "path/filepath" + "regexp" + "strings" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/media" +) + +const ( + qqMediaTypeImage = 1 + qqMediaTypeVideo = 2 + qqMediaTypeVoice = 3 + qqMediaTypeFile = 4 +) + +type qqTargetKind string + +const ( + qqTargetC2C qqTargetKind = "c2c" + qqTargetGroup qqTargetKind = "group" + qqTargetChannel qqTargetKind = "channel" +) + +type qqTarget struct { + Kind qqTargetKind + ID string +} + +var qqUUIDTargetPattern = regexp.MustCompile(`(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) + +type attachmentUpload struct { + Base64 string + FileName string + Mime string +} + +func (a *QQAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { + parsed, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + resolvedTarget, err := a.resolveTarget(ctx, msg.Target) + if err != nil { + return err + } + target, err := parseTarget(resolvedTarget) + if err != nil { + return err + } + client := a.getOrCreateClient(cfg, parsed) + replyTo := "" + if msg.Message.Reply != nil { + replyTo = strings.TrimSpace(msg.Message.Reply.MessageID) + } + + text := strings.TrimSpace(msg.Message.PlainText()) + if text != "" { + useMarkdown := parsed.MarkdownSupport && msg.Message.Format == channel.MessageFormatMarkdown && target.Kind != qqTargetChannel + if err := client.sendText(ctx, target, text, replyTo, useMarkdown); err != nil { + return err + } + } + + for _, att := range msg.Message.Attachments { + if err := a.sendAttachment(ctx, cfg, client, target, replyTo, att); err != nil { + return err + } + } + return nil +} + +func parseTarget(raw string) (qqTarget, error) { + normalized := normalizeTarget(raw) + switch { + case strings.HasPrefix(normalized, "c2c:"): + id := strings.TrimSpace(strings.TrimPrefix(normalized, "c2c:")) + if id == "" { + return qqTarget{}, errors.New("qq target c2c id is required") + } + if err := validateQQC2CTarget(id); err != nil { + return qqTarget{}, err + } + return qqTarget{Kind: qqTargetC2C, ID: id}, nil + case strings.HasPrefix(normalized, "group:"): + id := strings.TrimSpace(strings.TrimPrefix(normalized, "group:")) + if id == "" { + return qqTarget{}, errors.New("qq target group id is required") + } + return qqTarget{Kind: qqTargetGroup, ID: id}, nil + case strings.HasPrefix(normalized, "channel:"): + id := strings.TrimSpace(strings.TrimPrefix(normalized, "channel:")) + if id == "" { + return qqTarget{}, errors.New("qq target channel id is required") + } + return qqTarget{Kind: qqTargetChannel, ID: id}, nil + default: + return qqTarget{}, errors.New("unsupported qq target") + } +} + +func validateQQC2CTarget(id string) error { + if qqUUIDTargetPattern.MatchString(strings.TrimSpace(id)) { + return errors.New("qq c2c target must be user_openid, not an internal UUID; use c2c:") + } + return nil +} + +func (a *QQAdapter) sendAttachment(ctx context.Context, cfg channel.ChannelConfig, client *qqClient, target qqTarget, replyTo string, att channel.Attachment) error { + if target.Kind == qqTargetChannel { + switch att.Type { + case channel.AttachmentImage, channel.AttachmentGIF: + return errors.New("qq channel does not support image attachments") + case channel.AttachmentVideo: + return errors.New("qq channel does not support video attachments") + case channel.AttachmentVoice, channel.AttachmentAudio: + return errors.New("qq channel does not support voice attachments") + case channel.AttachmentFile, "": + return errors.New("qq channel does not support file attachments") + default: + return fmt.Errorf("unsupported qq attachment type: %s", att.Type) + } + } + + upload, err := a.prepareAttachmentUpload(ctx, cfg.BotID, att) + if err != nil { + return err + } + + switch att.Type { + case channel.AttachmentImage, channel.AttachmentGIF: + fileInfo, err := client.uploadMedia(ctx, target, qqMediaTypeImage, upload.Base64, "") + if err != nil { + return err + } + return client.sendMedia(ctx, target, fileInfo, replyTo, att.Caption) + case channel.AttachmentVideo: + fileInfo, err := client.uploadMedia(ctx, target, qqMediaTypeVideo, upload.Base64, "") + if err != nil { + return err + } + return client.sendMedia(ctx, target, fileInfo, replyTo, att.Caption) + case channel.AttachmentVoice, channel.AttachmentAudio: + if !supportsQQVoiceUpload(att, upload.FileName, upload.Mime) { + return errors.New("qq voice attachments require SILK/WAV/MP3/AMR input") + } + fileInfo, err := client.uploadMedia(ctx, target, qqMediaTypeVoice, upload.Base64, "") + if err != nil { + return err + } + return client.sendMedia(ctx, target, fileInfo, replyTo, att.Caption) + case channel.AttachmentFile, "": + fileInfo, err := client.uploadMedia(ctx, target, qqMediaTypeFile, upload.Base64, upload.FileName) + if err != nil { + return err + } + return client.sendMedia(ctx, target, fileInfo, replyTo, att.Caption) + default: + return fmt.Errorf("unsupported qq attachment type: %s", att.Type) + } +} + +func (a *QQAdapter) prepareAttachmentUpload(ctx context.Context, fallbackBotID string, att channel.Attachment) (attachmentUpload, error) { + if remoteURL := strings.TrimSpace(att.URL); strings.HasPrefix(strings.ToLower(remoteURL), "http://") || strings.HasPrefix(strings.ToLower(remoteURL), "https://") { + return a.prepareRemoteAttachmentUpload(ctx, att, remoteURL) + } + + if rawBase64 := extractRawBase64(att); rawBase64 != "" { + return attachmentUpload{ + Base64: rawBase64, + FileName: deriveAttachmentName(att), + Mime: strings.TrimSpace(att.Mime), + }, nil + } + + contentHash := strings.TrimSpace(att.ContentHash) + if contentHash == "" || a.assets == nil { + return attachmentUpload{}, errors.New("qq attachment requires http(s) URL, base64, or content_hash") + } + + botID := strings.TrimSpace(fallbackBotID) + if att.Metadata != nil { + if override, ok := att.Metadata["bot_id"].(string); ok && strings.TrimSpace(override) != "" { + botID = strings.TrimSpace(override) + } + } + if botID == "" { + return attachmentUpload{}, errors.New("qq attachment content_hash requires bot_id context") + } + + reader, asset, err := a.assets.Open(ctx, botID, contentHash) + if err != nil { + return attachmentUpload{}, err + } + defer func() { _ = reader.Close() }() + + data, err := media.ReadAllWithLimit(reader, media.MaxAssetBytes) + if err != nil { + return attachmentUpload{}, err + } + + fileName := deriveAttachmentName(att) + if fileName == "" { + fileName = deriveFileNameFromMime(asset.Mime, att.Type) + } + return attachmentUpload{ + Base64: base64.StdEncoding.EncodeToString(data), + FileName: fileName, + Mime: strings.TrimSpace(asset.Mime), + }, nil +} + +func (a *QQAdapter) prepareRemoteAttachmentUpload(ctx context.Context, att channel.Attachment, remoteURL string) (attachmentUpload, error) { + u, err := url.Parse(remoteURL) + if err != nil || (u.Scheme != "https" && u.Scheme != "http") || u.Host == "" { + return attachmentUpload{}, fmt.Errorf("invalid attachment url: %s", remoteURL) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, remoteURL, nil) + if err != nil { + return attachmentUpload{}, err + } + resp, err := a.httpClient.Do(req) //nolint:gosec // remote URL is validated to http(s) with non-empty host above + if err != nil { + return attachmentUpload{}, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return attachmentUpload{}, fmt.Errorf("qq attachment fetch failed: status=%d", resp.StatusCode) + } + + data, err := media.ReadAllWithLimit(resp.Body, media.MaxAssetBytes) + if err != nil { + return attachmentUpload{}, err + } + + mimeType := strings.TrimSpace(att.Mime) + if mimeType == "" { + mimeType = strings.TrimSpace(resp.Header.Get("Content-Type")) + if idx := strings.Index(mimeType, ";"); idx >= 0 { + mimeType = strings.TrimSpace(mimeType[:idx]) + } + } + + fileName := deriveAttachmentName(att) + if fileName == "" { + fileName = deriveFileNameFromMime(mimeType, att.Type) + } + + return attachmentUpload{ + Base64: base64.StdEncoding.EncodeToString(data), + FileName: fileName, + Mime: mimeType, + }, nil +} + +func extractRawBase64(att channel.Attachment) string { + if candidate := strings.TrimSpace(att.Base64); candidate != "" { + if strings.HasPrefix(strings.ToLower(candidate), "data:") { + if idx := strings.Index(candidate, ","); idx >= 0 && idx < len(candidate)-1 { + return candidate[idx+1:] + } + return "" + } + return candidate + } + + candidate := strings.TrimSpace(att.URL) + if strings.HasPrefix(strings.ToLower(candidate), "data:") { + if idx := strings.Index(candidate, ","); idx >= 0 && idx < len(candidate)-1 { + return candidate[idx+1:] + } + } + return "" +} + +func deriveAttachmentName(att channel.Attachment) string { + if name := strings.TrimSpace(att.Name); name != "" { + return name + } + if rawURL := strings.TrimSpace(att.URL); rawURL != "" && !strings.HasPrefix(strings.ToLower(rawURL), "data:") { + if base := filepath.Base(rawURL); base != "." && base != "/" && base != "" { + return base + } + } + return deriveFileNameFromMime(att.Mime, att.Type) +} + +func deriveFileNameFromMime(mimeType string, attType channel.AttachmentType) string { + ext := mimeExtension(mimeType) + base := "attachment" + switch attType { + case channel.AttachmentImage, channel.AttachmentGIF: + base = "image" + case channel.AttachmentVideo: + base = "video" + case channel.AttachmentVoice, channel.AttachmentAudio: + base = "audio" + case channel.AttachmentFile: + base = "file" + } + return base + ext +} + +func mimeExtension(mimeType string) string { + switch strings.ToLower(strings.TrimSpace(mimeType)) { + case "image/png": + return ".png" + case "image/jpeg", "image/jpg": + return ".jpg" + case "image/gif": + return ".gif" + case "image/webp": + return ".webp" + case "video/mp4": + return ".mp4" + case "audio/mpeg", "audio/mp3": + return ".mp3" + case "audio/wav", "audio/x-wav": + return ".wav" + case "audio/amr": + return ".amr" + case "application/pdf": + return ".pdf" + default: + return "" + } +} + +func supportsQQVoiceUpload(att channel.Attachment, fileName string, resolvedMime string) bool { + check := strings.ToLower(strings.TrimSpace(fileName)) + if check == "" { + check = strings.ToLower(strings.TrimSpace(att.Name)) + } + for _, ext := range []string{".silk", ".slk", ".amr", ".wav", ".mp3"} { + if strings.HasSuffix(check, ext) { + return true + } + } + mimeType := strings.ToLower(strings.TrimSpace(resolvedMime)) + if mimeType == "" { + mimeType = strings.ToLower(strings.TrimSpace(att.Mime)) + } + switch mimeType { + case "audio/silk", "audio/amr", "audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3": + return true + default: + return false + } +} diff --git a/internal/channel/adapters/qq/send_test.go b/internal/channel/adapters/qq/send_test.go new file mode 100644 index 00000000..7cb39e57 --- /dev/null +++ b/internal/channel/adapters/qq/send_test.go @@ -0,0 +1,770 @@ +package qq + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/memohai/memoh/internal/channel" + "github.com/memohai/memoh/internal/media" +) + +func TestQQSendTextReply(t *testing.T) { + t.Parallel() + + var tokenCalls int + var messageBodies []map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + tokenCalls++ + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/v2/users/user-openid/messages": + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode message body: %v", err) + } + messageBodies = append(messageBodies, body) + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-1"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-1", + BotID: "bot-1", + Credentials: map[string]any{ + "appId": "1024", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "c2c:user-openid", + Message: channel.Message{ + Text: "hello", + Reply: &channel.ReplyRef{MessageID: "source-msg"}, + }, + }) + if err != nil { + t.Fatalf("send text: %v", err) + } + + if tokenCalls != 1 { + t.Fatalf("unexpected token calls: %d", tokenCalls) + } + if len(messageBodies) != 1 { + t.Fatalf("unexpected message calls: %d", len(messageBodies)) + } + if messageBodies[0]["msg_id"] != "source-msg" { + t.Fatalf("unexpected msg_id: %#v", messageBodies[0]["msg_id"]) + } + if messageBodies[0]["msg_type"] != float64(0) { + t.Fatalf("unexpected msg_type: %#v", messageBodies[0]["msg_type"]) + } + if messageBodies[0]["content"] != "hello" { + t.Fatalf("unexpected content: %#v", messageBodies[0]["content"]) + } +} + +func TestQQSendImageAttachment(t *testing.T) { + t.Parallel() + + var uploadBody map[string]any + var messageBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/v2/groups/group-openid/files": + if err := json.NewDecoder(r.Body).Decode(&uploadBody); err != nil { + t.Fatalf("decode upload body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "file_uuid": "file-uuid-1", + "file_info": "file-info-1", + "ttl": 60, + }) + case "/v2/groups/group-openid/messages": + if err := json.NewDecoder(r.Body).Decode(&messageBody); err != nil { + t.Fatalf("decode message body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-2"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + adapter.SetAssetOpener(&trackingAssetOpener{ + data: []byte("png-bytes"), + asset: media.Asset{ + ContentHash: "hash-1", + BotID: "bot-2", + Mime: "image/png", + }, + }) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-2", + BotID: "bot-2", + Credentials: map[string]any{ + "appId": "2048", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "group:group-openid", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentImage, + ContentHash: "hash-1", + Name: "image.png", + }}, + Reply: &channel.ReplyRef{MessageID: "source-msg"}, + }, + }) + if err != nil { + t.Fatalf("send attachment: %v", err) + } + + if uploadBody["file_type"] != float64(qqMediaTypeImage) { + t.Fatalf("unexpected file_type: %#v", uploadBody["file_type"]) + } + if uploadBody["file_data"] != base64.StdEncoding.EncodeToString([]byte("png-bytes")) { + t.Fatalf("unexpected file_data: %#v", uploadBody["file_data"]) + } + if _, ok := uploadBody["file_name"]; ok { + t.Fatalf("unexpected file_name for image upload: %#v", uploadBody["file_name"]) + } + if messageBody["msg_type"] != float64(7) { + t.Fatalf("unexpected msg_type: %#v", messageBody["msg_type"]) + } + if messageBody["msg_id"] != "source-msg" { + t.Fatalf("unexpected msg_id: %#v", messageBody["msg_id"]) + } + media, ok := messageBody["media"].(map[string]any) + if !ok { + t.Fatalf("expected media payload: %#v", messageBody["media"]) + } + if media["file_info"] != "file-info-1" { + t.Fatalf("unexpected media.file_info: %#v", media["file_info"]) + } + if len(media) != 1 { + t.Fatalf("unexpected media payload size: %#v", media) + } +} + +func TestQQSendImageAttachmentCaptionUsesMediaContent(t *testing.T) { + t.Parallel() + + var uploadBody map[string]any + var messageBody map[string]any + var messageCalls int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/v2/users/user-openid/files": + if err := json.NewDecoder(r.Body).Decode(&uploadBody); err != nil { + t.Fatalf("decode upload body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "file_uuid": "file-uuid-2", + "file_info": "file-info-2", + "ttl": 60, + }) + case "/v2/users/user-openid/messages": + messageCalls++ + if err := json.NewDecoder(r.Body).Decode(&messageBody); err != nil { + t.Fatalf("decode message body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-2b"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-2b", + BotID: "bot-2b", + Credentials: map[string]any{ + "appId": "2049", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "c2c:user-openid", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentImage, + Base64: "data:image/png;base64,cG5nLWJ5dGVz", + Caption: "test.jpg from QQ", + }}, + }, + }) + if err != nil { + t.Fatalf("send attachment with caption: %v", err) + } + + if uploadBody["file_type"] != float64(qqMediaTypeImage) { + t.Fatalf("unexpected file_type: %#v", uploadBody["file_type"]) + } + if uploadBody["file_data"] != "cG5nLWJ5dGVz" { + t.Fatalf("unexpected file_data: %#v", uploadBody["file_data"]) + } + if messageCalls != 1 { + t.Fatalf("unexpected message calls: %d", messageCalls) + } + if messageBody["msg_type"] != float64(7) { + t.Fatalf("unexpected msg_type: %#v", messageBody["msg_type"]) + } + if messageBody["content"] != "test.jpg from QQ" { + t.Fatalf("unexpected content: %#v", messageBody["content"]) + } +} + +func TestQQProcessingStartedSendsInputHintForDirectMessages(t *testing.T) { + t.Parallel() + + var messageBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/v2/users/user-openid/messages": + if err := json.NewDecoder(r.Body).Decode(&messageBody); err != nil { + t.Fatalf("decode notify body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-3"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + _, err := adapter.ProcessingStarted(context.Background(), channel.ChannelConfig{ + ID: "cfg-3", + Credentials: map[string]any{ + "appId": "4096", + "clientSecret": "secret", + "enableInputHint": true, + }, + }, channel.InboundMessage{}, channel.ProcessingStatusInfo{ + ReplyTarget: "c2c:user-openid", + SourceMessageID: "source-msg", + }) + if err != nil { + t.Fatalf("processing started: %v", err) + } + if messageBody["msg_type"] != float64(6) { + t.Fatalf("unexpected msg_type: %#v", messageBody["msg_type"]) + } + if messageBody["msg_id"] != "source-msg" { + t.Fatalf("unexpected msg_id: %#v", messageBody["msg_id"]) + } +} + +func TestQQSendChannelImageIsUnsupported(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.NotFoundHandler()) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-4", + BotID: "bot-4", + Credentials: map[string]any{ + "appId": "8192", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "channel:channel-1", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentImage, + URL: "https://example.com/output.png", + }}, + }, + }) + if err == nil { + t.Fatal("expected channel image error") + } + if !strings.Contains(err.Error(), "does not support image attachments") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestQQSendChannelReplyIncludesMessageReference(t *testing.T) { + t.Parallel() + + var messageBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/channels/channel-1/messages": + if err := json.NewDecoder(r.Body).Decode(&messageBody); err != nil { + t.Fatalf("decode channel body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-6"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-5", + BotID: "bot-5", + Credentials: map[string]any{ + "appId": "16384", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "channel:channel-1", + Message: channel.Message{ + Text: "hello", + Reply: &channel.ReplyRef{MessageID: "source-msg"}, + }, + }) + if err != nil { + t.Fatalf("send channel reply: %v", err) + } + + if messageBody["msg_id"] != "source-msg" { + t.Fatalf("unexpected msg_id: %#v", messageBody["msg_id"]) + } + ref, ok := messageBody["message_reference"].(map[string]any) + if !ok { + t.Fatalf("expected message_reference: %#v", messageBody["message_reference"]) + } + if ref["message_id"] != "source-msg" { + t.Fatalf("unexpected message_reference.message_id: %#v", ref["message_id"]) + } +} + +func TestQQSendGroupFileUsesNativeUpload(t *testing.T) { + t.Parallel() + + var uploadBody map[string]any + var messageBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/v2/groups/group-openid/files": + if err := json.NewDecoder(r.Body).Decode(&uploadBody); err != nil { + t.Fatalf("decode upload body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "file_uuid": "file-uuid-7", + "file_info": "file-info-7", + "ttl": 60, + }) + case "/v2/groups/group-openid/messages": + if err := json.NewDecoder(r.Body).Decode(&messageBody); err != nil { + t.Fatalf("decode message body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-7"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-6", + BotID: "bot-6", + Credentials: map[string]any{ + "appId": "32768", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "group:group-openid", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentFile, + Base64: "JVBERi0xLjQ=", + Name: "report.pdf", + }}, + Reply: &channel.ReplyRef{MessageID: "source-msg"}, + }, + }) + if err != nil { + t.Fatalf("send group file: %v", err) + } + + if uploadBody["file_type"] != float64(qqMediaTypeFile) { + t.Fatalf("unexpected file_type: %#v", uploadBody["file_type"]) + } + if uploadBody["file_name"] != "report.pdf" { + t.Fatalf("unexpected file_name: %#v", uploadBody["file_name"]) + } + if uploadBody["file_data"] != "JVBERi0xLjQ=" { + t.Fatalf("unexpected file_data: %#v", uploadBody["file_data"]) + } + if messageBody["msg_type"] != float64(7) { + t.Fatalf("unexpected msg_type: %#v", messageBody["msg_type"]) + } + if messageBody["msg_id"] != "source-msg" { + t.Fatalf("unexpected msg_id: %#v", messageBody["msg_id"]) + } + media, ok := messageBody["media"].(map[string]any) + if !ok { + t.Fatalf("expected media payload: %#v", messageBody["media"]) + } + if media["file_info"] != "file-info-7" { + t.Fatalf("unexpected media.file_info: %#v", media["file_info"]) + } +} + +func TestQQSendChannelFileIsUnsupported(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.NotFoundHandler()) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-7", + BotID: "bot-7", + Credentials: map[string]any{ + "appId": "65536", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "channel:channel-1", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentFile, + URL: "https://example.com/files/report.pdf", + Name: "report.pdf", + }}, + Reply: &channel.ReplyRef{MessageID: "source-msg"}, + }, + }) + if err == nil { + t.Fatal("expected channel file error") + } + if !strings.Contains(err.Error(), "does not support file attachments") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestQQSendImageWithLocalPathFailsBeforeAPI(t *testing.T) { + t.Parallel() + + var tokenCalls int + var fileCalls int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + tokenCalls++ + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/v2/groups/group-openid/files": + fileCalls++ + _ = json.NewEncoder(w).Encode(map[string]any{"file_info": "unused"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-8", + BotID: "bot-8", + Credentials: map[string]any{ + "appId": "131072", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "group:group-openid", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentImage, + URL: "/tmp/output.png", + Name: "output.png", + }}, + }, + }) + if err == nil { + t.Fatal("expected local path error") + } + if !strings.Contains(err.Error(), "requires http(s) URL, base64, or content_hash") { + t.Fatalf("unexpected error: %v", err) + } + if tokenCalls != 0 { + t.Fatalf("unexpected token calls: %d", tokenCalls) + } + if fileCalls != 0 { + t.Fatalf("unexpected file upload calls: %d", fileCalls) + } +} + +func TestQQSendChannelImageFromStoredAssetIsUnsupported(t *testing.T) { + t.Parallel() + + var tokenCalls int + var messageCalls int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + tokenCalls++ + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/channels/channel-1/messages": + messageCalls++ + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-5"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + opener := &trackingAssetOpener{ + data: []byte("png-bytes"), + asset: media.Asset{ + ContentHash: "hash-1", + BotID: "bot-4", + Mime: "image/png", + }, + } + adapter.SetAssetOpener(opener) + + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-4", + BotID: "bot-4", + Credentials: map[string]any{ + "appId": "8192", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "channel:channel-1", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentImage, + ContentHash: "hash-1", + Name: "image.png", + }}, + }, + }) + if err == nil { + t.Fatal("expected channel image error") + } + if !strings.Contains(err.Error(), "does not support image attachments") { + t.Fatalf("unexpected error: %v", err) + } + if opener.called { + t.Fatal("expected stored asset opener to be skipped for channel images") + } + if tokenCalls != 0 { + t.Fatalf("unexpected token calls: %d", tokenCalls) + } + if messageCalls != 0 { + t.Fatalf("unexpected channel message calls: %d", messageCalls) + } +} + +func TestQQSendVoiceAttachmentFromHTTPURLUsesDetectedMime(t *testing.T) { + t.Parallel() + + const voiceBytes = "remote-voice-bytes" + + var uploadBody map[string]any + var messageBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/remote/voice": + w.Header().Set("Content-Type", "audio/mpeg") + _, _ = w.Write([]byte(voiceBytes)) + case "/v2/groups/group-openid/files": + if err := json.NewDecoder(r.Body).Decode(&uploadBody); err != nil { + t.Fatalf("decode upload body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "file_uuid": "file-uuid-voice", + "file_info": "file-info-voice", + "ttl": 60, + }) + case "/v2/groups/group-openid/messages": + if err := json.NewDecoder(r.Body).Decode(&messageBody); err != nil { + t.Fatalf("decode message body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-voice"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-voice", + BotID: "bot-voice", + Credentials: map[string]any{ + "appId": "524288", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "group:group-openid", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentVoice, + URL: server.URL + "/remote/voice", + }}, + }, + }) + if err != nil { + t.Fatalf("send remote voice attachment: %v", err) + } + if uploadBody["file_type"] != float64(qqMediaTypeVoice) { + t.Fatalf("unexpected file_type: %#v", uploadBody["file_type"]) + } + if uploadBody["file_data"] != base64.StdEncoding.EncodeToString([]byte(voiceBytes)) { + t.Fatalf("unexpected file_data: %#v", uploadBody["file_data"]) + } + mediaPayload, ok := messageBody["media"].(map[string]any) + if !ok { + t.Fatalf("expected media payload: %#v", messageBody["media"]) + } + if mediaPayload["file_info"] != "file-info-voice" { + t.Fatalf("unexpected media.file_info: %#v", mediaPayload["file_info"]) + } +} + +func TestQQSendImageAttachmentFromHTTPURLUsesFetchedBytes(t *testing.T) { + t.Parallel() + + const imageBytes = "remote-image-bytes" + + var uploadBody map[string]any + var messageBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/app/getAppAccessToken": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "token-1", + "expires_in": 7200, + }) + case "/remote/test.jpg": + w.Header().Set("Content-Type", "image/jpeg") + _, _ = w.Write([]byte(imageBytes)) + case "/v2/groups/group-openid/files": + if err := json.NewDecoder(r.Body).Decode(&uploadBody); err != nil { + t.Fatalf("decode upload body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "file_uuid": "file-uuid-9", + "file_info": "file-info-9", + "ttl": 60, + }) + case "/v2/groups/group-openid/messages": + if err := json.NewDecoder(r.Body).Decode(&messageBody); err != nil { + t.Fatalf("decode message body: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "m-9"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + adapter := newTestQQAdapter(server) + err := adapter.Send(context.Background(), channel.ChannelConfig{ + ID: "cfg-9", + BotID: "bot-9", + Credentials: map[string]any{ + "appId": "262144", + "clientSecret": "secret", + }, + }, channel.OutboundMessage{ + Target: "group:group-openid", + Message: channel.Message{ + Attachments: []channel.Attachment{{ + Type: channel.AttachmentImage, + URL: server.URL + "/remote/test.jpg", + Name: "test.jpg", + }}, + }, + }) + if err != nil { + t.Fatalf("send remote image attachment: %v", err) + } + if uploadBody["file_type"] != float64(qqMediaTypeImage) { + t.Fatalf("unexpected file_type: %#v", uploadBody["file_type"]) + } + if uploadBody["file_data"] != base64.StdEncoding.EncodeToString([]byte(imageBytes)) { + t.Fatalf("unexpected file_data: %#v", uploadBody["file_data"]) + } + if _, ok := uploadBody["url"]; ok { + t.Fatalf("unexpected qq native url upload payload: %#v", uploadBody["url"]) + } + mediaPayload, ok := messageBody["media"].(map[string]any) + if !ok { + t.Fatalf("expected media payload: %#v", messageBody["media"]) + } + if mediaPayload["file_info"] != "file-info-9" { + t.Fatalf("unexpected media.file_info: %#v", mediaPayload["file_info"]) + } +} + +type trackingAssetOpener struct { + called bool + data []byte + asset media.Asset +} + +func (t *trackingAssetOpener) Open(context.Context, string, string) (io.ReadCloser, media.Asset, error) { + t.called = true + return io.NopCloser(bytes.NewReader(t.data)), t.asset, nil +} + +func newTestQQAdapter(server *httptest.Server) *QQAdapter { + adapter := NewQQAdapter(nil) + adapter.httpClient = server.Client() + adapter.apiBaseURL = server.URL + adapter.tokenURL = server.URL + "/app/getAppAccessToken" + return adapter +} diff --git a/internal/channel/adapters/qq/stream.go b/internal/channel/adapters/qq/stream.go new file mode 100644 index 00000000..9677f55f --- /dev/null +++ b/internal/channel/adapters/qq/stream.go @@ -0,0 +1,152 @@ +package qq + +import ( + "context" + "errors" + "strings" + "sync" + "sync/atomic" + + "github.com/memohai/memoh/internal/channel" +) + +type qqOutboundStream struct { + target string + reply *channel.ReplyRef + send func(context.Context, channel.OutboundMessage) error + + closed atomic.Bool + mu sync.Mutex + buffer strings.Builder + attachments []channel.Attachment + sentText bool +} + +func (a *QQAdapter) OpenStream(_ context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + return &qqOutboundStream{ + target: target, + reply: opts.Reply, + send: func(ctx context.Context, msg channel.OutboundMessage) error { + if msg.Target == "" { + msg.Target = target + } + if msg.Message.Reply == nil && opts.Reply != nil { + msg.Message.Reply = opts.Reply + } + return a.Send(ctx, cfg, msg) + }, + }, nil +} + +func (s *qqOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s == nil || s.send == nil { + return errors.New("qq stream not configured") + } + if s.closed.Load() { + return errors.New("qq stream is closed") + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + switch event.Type { + case channel.StreamEventStatus, + channel.StreamEventPhaseStart, + channel.StreamEventPhaseEnd, + channel.StreamEventToolCallStart, + channel.StreamEventToolCallEnd, + channel.StreamEventAgentStart, + channel.StreamEventAgentEnd, + channel.StreamEventProcessingStarted, + channel.StreamEventProcessingCompleted, + channel.StreamEventProcessingFailed: + return nil + case channel.StreamEventDelta: + if event.Phase == channel.StreamPhaseReasoning || event.Delta == "" { + return nil + } + s.mu.Lock() + s.buffer.WriteString(event.Delta) + s.mu.Unlock() + return nil + case channel.StreamEventAttachment: + if len(event.Attachments) == 0 { + return nil + } + s.mu.Lock() + s.attachments = append(s.attachments, event.Attachments...) + s.mu.Unlock() + return nil + case channel.StreamEventError: + errText := strings.TrimSpace(event.Error) + if errText == "" { + return nil + } + return s.flush(ctx, channel.Message{ + Text: "Error: " + errText, + }) + case channel.StreamEventFinal: + if event.Final == nil { + return errors.New("qq stream final payload is required") + } + return s.flush(ctx, event.Final.Message) + default: + return nil + } +} + +func (s *qqOutboundStream) Close(ctx context.Context) error { + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil +} + +func (s *qqOutboundStream) flush(ctx context.Context, msg channel.Message) error { + s.mu.Lock() + bufferedText := strings.TrimSpace(s.buffer.String()) + bufferedAttachments := append([]channel.Attachment(nil), s.attachments...) + alreadySentText := s.sentText + s.buffer.Reset() + s.attachments = nil + s.mu.Unlock() + + if bufferedText != "" { + msg.Text = bufferedText + msg.Parts = nil + if msg.Format == "" { + msg.Format = channel.MessageFormatPlain + } + } else if alreadySentText && len(bufferedAttachments) == 0 && len(msg.Attachments) == 0 && strings.TrimSpace(msg.PlainText()) != "" { + return nil + } + if len(bufferedAttachments) > 0 { + msg.Attachments = append(bufferedAttachments, msg.Attachments...) + } + if msg.Reply == nil && s.reply != nil { + msg.Reply = s.reply + } + if msg.IsEmpty() { + return nil + } + if err := s.send(ctx, channel.OutboundMessage{ + Target: s.target, + Message: msg, + }); err != nil { + return err + } + if strings.TrimSpace(msg.PlainText()) != "" { + s.mu.Lock() + s.sentText = true + s.mu.Unlock() + } + return nil +} diff --git a/internal/channel/adapters/qq/stream_test.go b/internal/channel/adapters/qq/stream_test.go new file mode 100644 index 00000000..57d0ad97 --- /dev/null +++ b/internal/channel/adapters/qq/stream_test.go @@ -0,0 +1,171 @@ +package qq + +import ( + "context" + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +func TestQQOutboundStreamFlushesBufferedTextOnFinal(t *testing.T) { + t.Parallel() + + var sent []channel.OutboundMessage + stream := &qqOutboundStream{ + target: "c2c:user-openid", + send: func(_ context.Context, msg channel.OutboundMessage) error { + sent = append(sent, msg) + return nil + }, + } + + ctx := context.Background() + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventStatus, Status: channel.StreamStatusStarted}); err != nil { + t.Fatalf("push status: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "Hi "}); err != nil { + t.Fatalf("push delta1: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "there"}); err != nil { + t.Fatalf("push delta2: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{}}); err != nil { + t.Fatalf("push final: %v", err) + } + + if len(sent) != 1 { + t.Fatalf("expected one send, got %d", len(sent)) + } + if sent[0].Target != "c2c:user-openid" { + t.Fatalf("unexpected target: %s", sent[0].Target) + } + if sent[0].Message.PlainText() != "Hi there" { + t.Fatalf("unexpected text: %q", sent[0].Message.PlainText()) + } +} + +func TestQQOutboundStreamFinalUsesExplicitMessageAndBufferedAttachments(t *testing.T) { + t.Parallel() + + var sent []channel.OutboundMessage + stream := &qqOutboundStream{ + target: "group:group-openid", + send: func(_ context.Context, msg channel.OutboundMessage) error { + sent = append(sent, msg) + return nil + }, + } + + ctx := context.Background() + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventAttachment, + Attachments: []channel.Attachment{{Type: channel.AttachmentImage, URL: "https://example.com/a.png"}}, + }); err != nil { + t.Fatalf("push attachment: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventFinal, + Final: &channel.StreamFinalizePayload{Message: channel.Message{ + Text: "done", + }}, + }); err != nil { + t.Fatalf("push final: %v", err) + } + + if len(sent) != 1 { + t.Fatalf("expected one send, got %d", len(sent)) + } + if sent[0].Message.PlainText() != "done" { + t.Fatalf("unexpected text: %q", sent[0].Message.PlainText()) + } + if len(sent[0].Message.Attachments) != 1 { + t.Fatalf("unexpected attachments: %d", len(sent[0].Message.Attachments)) + } +} + +func TestQQOutboundStreamFinalPrefersBufferedVisibleText(t *testing.T) { + t.Parallel() + + var sent []channel.OutboundMessage + stream := &qqOutboundStream{ + target: "c2c:user-openid", + send: func(_ context.Context, msg channel.OutboundMessage) error { + sent = append(sent, msg) + return nil + }, + } + + ctx := context.Background() + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "visible "}); err != nil { + t.Fatalf("push delta1: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "answer"}); err != nil { + t.Fatalf("push delta2: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventFinal, + Final: &channel.StreamFinalizePayload{Message: channel.Message{ + Text: "internal trace\nvisible answer", + }}, + }); err != nil { + t.Fatalf("push final: %v", err) + } + + if len(sent) != 1 { + t.Fatalf("expected one send, got %d", len(sent)) + } + if got := sent[0].Message.PlainText(); got != "visible answer" { + t.Fatalf("unexpected text: %q", got) + } +} + +func TestQQOutboundStreamIgnoresLaterTextOnlyFinalAfterBufferedReply(t *testing.T) { + t.Parallel() + + var sent []channel.OutboundMessage + stream := &qqOutboundStream{ + target: "c2c:user-openid", + send: func(_ context.Context, msg channel.OutboundMessage) error { + sent = append(sent, msg) + return nil + }, + } + + ctx := context.Background() + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventDelta, Delta: "visible answer"}); err != nil { + t.Fatalf("push delta: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{Type: channel.StreamEventFinal, Final: &channel.StreamFinalizePayload{}}); err != nil { + t.Fatalf("push first final: %v", err) + } + if err := stream.Push(ctx, channel.StreamEvent{ + Type: channel.StreamEventFinal, + Final: &channel.StreamFinalizePayload{Message: channel.Message{ + Text: "我需要按照用户的要求,在工具调用后完整复述。", + }}, + }); err != nil { + t.Fatalf("push second final: %v", err) + } + + if len(sent) != 1 { + t.Fatalf("expected 1 outbound message, got %d", len(sent)) + } + if got := sent[0].Message.PlainText(); got != "visible answer" { + t.Fatalf("unexpected text: %q", got) + } +} + +func TestQQOutboundStreamRejectsAfterClose(t *testing.T) { + t.Parallel() + + stream := &qqOutboundStream{} + if err := stream.Close(context.Background()); err != nil { + t.Fatalf("close: %v", err) + } + if err := stream.Push(context.Background(), channel.StreamEvent{ + Type: channel.StreamEventDelta, + Delta: "x", + }); err == nil { + t.Fatal("expected closed error") + } +} diff --git a/internal/channel/adapters/qq/target_resolver.go b/internal/channel/adapters/qq/target_resolver.go new file mode 100644 index 00000000..5f7d8e5f --- /dev/null +++ b/internal/channel/adapters/qq/target_resolver.go @@ -0,0 +1,136 @@ +package qq + +import ( + "context" + "errors" + "regexp" + "strings" + + "github.com/jackc/pgx/v5" + + identitypkg "github.com/memohai/memoh/internal/channel/identities" +) + +var qqOpenIDPattern = regexp.MustCompile(`(?i)^[0-9a-f]{32}$`) + +func (a *QQAdapter) resolveTarget(ctx context.Context, raw string) (string, error) { + target := normalizeTarget(raw) + if !strings.HasPrefix(target, "c2c:") { + return target, nil + } + id := strings.TrimSpace(strings.TrimPrefix(target, "c2c:")) + if !qqUUIDTargetPattern.MatchString(id) { + return target, nil + } + if mapped, found, err := a.resolveRouteTarget(ctx, id); err != nil { + return "", err + } else if found { + return normalizeTarget(mapped), nil + } + if mapped, found, err := a.resolveIdentityTarget(ctx, id); err != nil { + return "", err + } else if found { + return normalizeTarget(mapped), nil + } + return target, nil +} + +func (a *QQAdapter) resolveRouteTarget(ctx context.Context, routeID string) (string, bool, error) { + resolver := a.getRouteResolver() + if resolver == nil { + return "", false, nil + } + item, err := resolver.GetByID(ctx, routeID) + if err != nil { + if isQQLookupMiss(err) { + return "", false, nil + } + return "", false, err + } + if !strings.EqualFold(strings.TrimSpace(item.Platform), string(Type)) { + return "", false, nil + } + target := strings.TrimSpace(item.ReplyTarget) + if target == "" { + return "", false, nil + } + return target, true, nil +} + +func (a *QQAdapter) resolveIdentityTarget(ctx context.Context, id string) (string, bool, error) { + resolver := a.getIdentityResolver() + if resolver == nil { + return "", false, nil + } + if mapped, found, err := lookupQQIdentityTarget(ctx, resolver.ListCanonicalChannelIdentities, id); err != nil { + return "", false, err + } else if found { + return mapped, true, nil + } + if mapped, found, err := lookupQQIdentityTarget(ctx, resolver.ListUserChannelIdentities, id); err != nil { + return "", false, err + } else if found { + return mapped, true, nil + } + item, err := resolver.GetByID(ctx, id) + if err != nil { + if isQQLookupMiss(err) { + return "", false, nil + } + return "", false, err + } + if mapped := qqIdentityTarget(item); mapped != "" { + return mapped, true, nil + } + return "", false, nil +} + +func lookupQQIdentityTarget(ctx context.Context, lookup func(context.Context, string) ([]identitypkg.ChannelIdentity, error), id string) (string, bool, error) { + items, err := lookup(ctx, id) + if err != nil { + if isQQLookupMiss(err) { + return "", false, nil + } + return "", false, err + } + if mapped := firstQQIdentityTarget(items); mapped != "" { + return mapped, true, nil + } + return "", false, nil +} + +func firstQQIdentityTarget(items []identitypkg.ChannelIdentity) string { + for _, item := range items { + if target := qqIdentityTarget(item); target != "" { + return target + } + } + return "" +} + +func qqIdentityTarget(item identitypkg.ChannelIdentity) string { + if !strings.EqualFold(strings.TrimSpace(item.Channel), string(Type)) { + return "" + } + subjectID := strings.TrimSpace(item.ChannelSubjectID) + if !qqOpenIDPattern.MatchString(subjectID) { + return "" + } + return "c2c:" + subjectID +} + +func isQQLookupMiss(err error) bool { + return errors.Is(err, pgx.ErrNoRows) || errors.Is(err, identitypkg.ErrChannelIdentityNotFound) +} + +func (a *QQAdapter) getRouteResolver() routeResolver { + a.mu.Lock() + defer a.mu.Unlock() + return a.routes +} + +func (a *QQAdapter) getIdentityResolver() channelIdentityResolver { + a.mu.Lock() + defer a.mu.Unlock() + return a.identity +} diff --git a/internal/channel/adapters/qq/target_resolver_test.go b/internal/channel/adapters/qq/target_resolver_test.go new file mode 100644 index 00000000..870df515 --- /dev/null +++ b/internal/channel/adapters/qq/target_resolver_test.go @@ -0,0 +1,201 @@ +package qq + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + + identitypkg "github.com/memohai/memoh/internal/channel/identities" + routepkg "github.com/memohai/memoh/internal/channel/route" +) + +const testQQOpenID = "00112233445566778899AABBCCDDEEFF" + +func TestQQResolveTargetMapsRouteID(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + adapter.SetRouteResolver(&fakeQQRouteResolver{ + byID: map[string]routepkg.Route{ + "3fe2bad9-3eae-4f23-872c-b7a63662aa00": { + ID: "3fe2bad9-3eae-4f23-872c-b7a63662aa00", + Platform: "qq", + ReplyTarget: "c2c:" + testQQOpenID, + }, + }, + }) + + got, err := adapter.resolveTarget(context.Background(), "3fe2bad9-3eae-4f23-872c-b7a63662aa00") + if err != nil { + t.Fatalf("resolveTarget returned error: %v", err) + } + if got != "c2c:"+testQQOpenID { + t.Fatalf("unexpected mapped target: %q", got) + } +} + +func TestQQResolveTargetMapsIdentityID(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + adapter.SetChannelIdentityResolver(&fakeQQIdentityResolver{ + canonical: map[string][]identitypkg.ChannelIdentity{ + "3fe2bad9-3eae-4f23-872c-b7a63662aa00": { + {ID: "qq-identity-1", Channel: "qq", ChannelSubjectID: testQQOpenID}, + }, + }, + }) + + got, err := adapter.resolveTarget(context.Background(), "3fe2bad9-3eae-4f23-872c-b7a63662aa00") + if err != nil { + t.Fatalf("resolveTarget returned error: %v", err) + } + if got != "c2c:"+testQQOpenID { + t.Fatalf("unexpected mapped target: %q", got) + } +} + +func TestQQResolveTargetMapsUserID(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + adapter.SetChannelIdentityResolver(&fakeQQIdentityResolver{ + userScoped: map[string][]identitypkg.ChannelIdentity{ + "3fe2bad9-3eae-4f23-872c-b7a63662aa00": { + {ID: "qq-identity-1", Channel: "qq", ChannelSubjectID: testQQOpenID}, + }, + }, + }) + + got, err := adapter.resolveTarget(context.Background(), "3fe2bad9-3eae-4f23-872c-b7a63662aa00") + if err != nil { + t.Fatalf("resolveTarget returned error: %v", err) + } + if got != "c2c:"+testQQOpenID { + t.Fatalf("unexpected mapped target: %q", got) + } +} + +func TestQQResolveTargetSkipsNonOpenIDQQIdentity(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + adapter.SetChannelIdentityResolver(&fakeQQIdentityResolver{ + canonical: map[string][]identitypkg.ChannelIdentity{ + "3fe2bad9-3eae-4f23-872c-b7a63662aa00": { + {ID: "qq-guild-identity-1", Channel: "qq", ChannelSubjectID: "guild-user-id"}, + }, + }, + }) + + got, err := adapter.resolveTarget(context.Background(), "3fe2bad9-3eae-4f23-872c-b7a63662aa00") + if err != nil { + t.Fatalf("resolveTarget returned error: %v", err) + } + if got != "c2c:3fe2bad9-3eae-4f23-872c-b7a63662aa00" { + t.Fatalf("unexpected mapped target: %q", got) + } +} + +func TestQQResolveTargetReturnsRouteResolverErrors(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + adapter.SetRouteResolver(&fakeQQRouteResolver{err: errors.New("route store unavailable")}) + + _, err := adapter.resolveTarget(context.Background(), "3fe2bad9-3eae-4f23-872c-b7a63662aa00") + if err == nil { + t.Fatal("expected route resolver error") + } + if !strings.Contains(err.Error(), "route store unavailable") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestQQResolveTargetReturnsIdentityResolverErrors(t *testing.T) { + t.Parallel() + + adapter := NewQQAdapter(nil) + adapter.SetChannelIdentityResolver(&fakeQQIdentityResolver{canonicalErr: errors.New("identity store unavailable")}) + + _, err := adapter.resolveTarget(context.Background(), "3fe2bad9-3eae-4f23-872c-b7a63662aa00") + if err == nil { + t.Fatal("expected identity resolver error") + } + if !strings.Contains(err.Error(), "identity store unavailable") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParseTargetRejectsUUIDForC2C(t *testing.T) { + t.Parallel() + + _, err := parseTarget("3fe2bad9-3eae-4f23-872c-b7a63662aa00") + if err == nil { + t.Fatal("expected c2c uuid target error") + } + if !strings.Contains(err.Error(), "user_openid") { + t.Fatalf("unexpected error: %v", err) + } +} + +type fakeQQIdentityResolver struct { + byID map[string]identitypkg.ChannelIdentity + canonical map[string][]identitypkg.ChannelIdentity + userScoped map[string][]identitypkg.ChannelIdentity + byIDErr error + canonicalErr error + userErr error +} + +func (f *fakeQQIdentityResolver) GetByID(_ context.Context, channelIdentityID string) (identitypkg.ChannelIdentity, error) { + if f.byIDErr != nil { + return identitypkg.ChannelIdentity{}, f.byIDErr + } + item, ok := f.byID[channelIdentityID] + if !ok { + return identitypkg.ChannelIdentity{}, identitypkg.ErrChannelIdentityNotFound + } + return item, nil +} + +func (f *fakeQQIdentityResolver) ListCanonicalChannelIdentities(_ context.Context, channelIdentityID string) ([]identitypkg.ChannelIdentity, error) { + if f.canonicalErr != nil { + return nil, f.canonicalErr + } + items, ok := f.canonical[channelIdentityID] + if !ok { + return nil, identitypkg.ErrChannelIdentityNotFound + } + return items, nil +} + +func (f *fakeQQIdentityResolver) ListUserChannelIdentities(_ context.Context, userID string) ([]identitypkg.ChannelIdentity, error) { + if f.userErr != nil { + return nil, f.userErr + } + items, ok := f.userScoped[userID] + if !ok { + return nil, nil + } + return items, nil +} + +type fakeQQRouteResolver struct { + byID map[string]routepkg.Route + err error +} + +func (f *fakeQQRouteResolver) GetByID(_ context.Context, routeID string) (routepkg.Route, error) { + if f.err != nil { + return routepkg.Route{}, f.err + } + item, ok := f.byID[routeID] + if !ok { + return routepkg.Route{}, pgx.ErrNoRows + } + return item, nil +} diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index 1f66ec60..ad7b1cbd 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -1480,10 +1480,11 @@ func (p *ChannelInboundProcessor) ingestInboundAttachments( item.Mime = finalMime maxBytes := media.MaxAssetBytes asset, err := p.mediaService.Ingest(ctx, media.IngestInput{ - BotID: botID, - Mime: strings.TrimSpace(item.Mime), - Reader: preparedReader, - MaxBytes: maxBytes, + BotID: botID, + Mime: strings.TrimSpace(item.Mime), + Reader: preparedReader, + MaxBytes: maxBytes, + OriginalExt: filepath.Ext(strings.TrimSpace(item.Name)), }) if payload.reader != nil { _ = payload.reader.Close() diff --git a/internal/channel/inbound/channel_test.go b/internal/channel/inbound/channel_test.go index dbc2febc..cc6d4f32 100644 --- a/internal/channel/inbound/channel_test.go +++ b/internal/channel/inbound/channel_test.go @@ -1,6 +1,7 @@ package inbound import ( + "bytes" "context" "encoding/base64" "encoding/json" @@ -224,15 +225,54 @@ func (*fakeMediaIngestor) AccessPath(asset media.Asset) string { return "/data/media/" + asset.StorageKey } -type fakeAttachmentResolverAdapter struct{} +type fakeStorageProvider struct { + objects map[string][]byte +} -func (*fakeAttachmentResolverAdapter) Type() channel.ChannelType { +func (f *fakeStorageProvider) Put(_ context.Context, key string, reader io.Reader) error { + if f.objects == nil { + f.objects = make(map[string][]byte) + } + payload, err := io.ReadAll(reader) + if err != nil { + return err + } + f.objects[key] = payload + return nil +} + +func (f *fakeStorageProvider) Open(_ context.Context, key string) (io.ReadCloser, error) { + payload, ok := f.objects[key] + if !ok { + return nil, errors.New("not found") + } + return io.NopCloser(bytes.NewReader(payload)), nil +} + +func (f *fakeStorageProvider) Delete(_ context.Context, key string) error { + delete(f.objects, key) + return nil +} + +func (*fakeStorageProvider) AccessPath(key string) string { + return "/data/media/" + key +} + +type fakeAttachmentResolverAdapter struct { + typ channel.ChannelType + payload channel.AttachmentPayload +} + +func (f *fakeAttachmentResolverAdapter) Type() channel.ChannelType { + if f != nil && strings.TrimSpace(f.typ.String()) != "" { + return f.typ + } return channel.ChannelType("resolver-test") } -func (*fakeAttachmentResolverAdapter) Descriptor() channel.Descriptor { +func (f *fakeAttachmentResolverAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ - Type: channel.ChannelType("resolver-test"), + Type: f.Type(), DisplayName: "ResolverTest", Capabilities: channel.ChannelCapabilities{ Text: true, @@ -241,7 +281,10 @@ func (*fakeAttachmentResolverAdapter) Descriptor() channel.Descriptor { } } -func (*fakeAttachmentResolverAdapter) ResolveAttachment(_ context.Context, _ channel.ChannelConfig, _ channel.Attachment) (channel.AttachmentPayload, error) { +func (f *fakeAttachmentResolverAdapter) ResolveAttachment(_ context.Context, _ channel.ChannelConfig, _ channel.Attachment) (channel.AttachmentPayload, error) { + if f != nil && f.payload.Reader != nil { + return f.payload, nil + } return channel.AttachmentPayload{ Reader: io.NopCloser(strings.NewReader("resolver-bytes")), Mime: "application/octet-stream", @@ -591,6 +634,57 @@ func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) { } } +type failingOpenStreamSender struct { + err error +} + +func (*failingOpenStreamSender) Send(_ context.Context, _ channel.OutboundMessage) error { + return nil +} + +func (s *failingOpenStreamSender) OpenStream(_ context.Context, _ string, _ channel.StreamOptions) (channel.OutboundStream, error) { + if s != nil && s.err != nil { + return nil, s.err + } + return nil, errors.New("open stream failed") +} + +func TestChannelInboundProcessorPersistsActiveChatBeforeOpenStream(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-openstream"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-openstream", RouteID: "route-openstream"}} + gateway := &fakeChatGateway{} + processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + sender := &failingOpenStreamSender{err: errors.New("stream unavailable")} + + cfg := channel.ChannelConfig{ID: "cfg-openstream", BotID: "bot-1", ChannelType: channel.ChannelType("qq")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("qq"), + Message: channel.Message{ID: "msg-openstream-1", Text: "hello"}, + ReplyTarget: "c2c:user-openid", + Sender: channel.Identity{SubjectID: "user-1"}, + Conversation: channel.Conversation{ + ID: "conv-openstream", + Type: "p2p", + }, + } + + err := processor.HandleInbound(context.Background(), cfg, msg, sender) + if err == nil || err.Error() != "stream unavailable" { + t.Fatalf("expected open stream error, got: %v", err) + } + if len(chatSvc.persistedIn) != 1 { + t.Fatalf("expected active-chat user turn to be persisted before stream open, got %d", len(chatSvc.persistedIn)) + } + if got := chatSvc.persistedIn[0].ExternalMessageID; got != "msg-openstream-1" { + t.Fatalf("unexpected persisted external_message_id: %q", got) + } + if gateway.gotReq.Query != "" { + t.Fatalf("runner should not be called when stream open fails") + } +} + func TestChannelInboundProcessorPersistsAttachmentAssetRefs(t *testing.T) { channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-asset"}} memberSvc := &fakeMemberService{isMember: true} @@ -785,6 +879,71 @@ func TestChannelInboundProcessorIngestsBase64Attachment(t *testing.T) { } } +func TestChannelInboundProcessorIngestsQQFileAttachmentKeepsOriginalExtWhenMimeGeneric(t *testing.T) { + channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-qq-file"}} + memberSvc := &fakeMemberService{isMember: true} + chatSvc := &fakeChatService{resolveResult: route.ResolveConversationResult{ChatID: "chat-qq-file", RouteID: "route-qq-file"}} + gateway := &fakeChatGateway{ + resp: conversation.ChatResponse{ + Messages: []conversation.ModelMessage{ + {Role: "assistant", Content: conversation.NewTextContent("ok")}, + }, + }, + } + registry := channel.NewRegistry() + registry.MustRegister(&fakeAttachmentResolverAdapter{ + typ: channel.ChannelType("qq"), + payload: channel.AttachmentPayload{ + Reader: io.NopCloser(bytes.NewReader([]byte{0x00, 0x01, 0x02, 0x03, 0x04})), + Mime: "application/octet-stream", + Size: 5, + }, + }) + processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) + storage := &fakeStorageProvider{} + mediaSvc := media.NewService(slog.Default(), storage) + processor.SetMediaService(mediaSvc) + sender := &fakeReplySender{} + + cfg := channel.ChannelConfig{ID: "cfg-qq-file", BotID: "bot-1", ChannelType: channel.ChannelType("qq")} + msg := channel.InboundMessage{ + BotID: "bot-1", + Channel: channel.ChannelType("qq"), + Message: channel.Message{ + ID: "msg-qq-file-1", + Text: "[User sent 1 attachment]", + Attachments: []channel.Attachment{ + { + Type: channel.AttachmentFile, + PlatformKey: "qq-file-1", + Name: "test.md", + Mime: "file", + }, + }, + }, + ReplyTarget: "c2c:user-openid", + Sender: channel.Identity{SubjectID: "qq-user"}, + Conversation: channel.Conversation{ + ID: "qq-user", + Type: "direct", + }, + } + + if err := processor.HandleInbound(context.Background(), cfg, msg, sender); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gateway.gotReq.Attachments) != 1 { + t.Fatalf("expected one attachment in gateway request, got %d", len(gateway.gotReq.Attachments)) + } + storageKey, _ := gateway.gotReq.Attachments[0].Metadata["storage_key"].(string) + if !strings.HasSuffix(storageKey, ".md") { + t.Fatalf("expected storage key to keep .md extension, got %q", storageKey) + } + if strings.HasSuffix(storageKey, ".bin") { + t.Fatalf("expected storage key to avoid .bin fallback, got %q", storageKey) + } +} + func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) { channelIdentitySvc := &fakeChannelIdentityService{channelIdentity: identities.ChannelIdentity{ID: "channelIdentity-member"}} memberSvc := &fakeMemberService{isMember: true} diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go index c92e4b0d..b7e31b4e 100644 --- a/internal/channel/outbound.go +++ b/internal/channel/outbound.go @@ -34,12 +34,13 @@ type Chunker func(text string, limit int) []string // OutboundPolicy configures how outbound messages are chunked, ordered, and retried. type OutboundPolicy struct { - TextChunkLimit int `json:"text_chunk_limit,omitempty"` - ChunkerMode ChunkerMode `json:"chunker_mode,omitempty"` - Chunker Chunker `json:"-"` - MediaOrder OutboundOrder `json:"media_order,omitempty"` - RetryMax int `json:"retry_max,omitempty"` - RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` + TextChunkLimit int `json:"text_chunk_limit,omitempty"` + ChunkerMode ChunkerMode `json:"chunker_mode,omitempty"` + Chunker Chunker `json:"-"` + MediaOrder OutboundOrder `json:"media_order,omitempty"` + InlineTextWithMedia bool `json:"inline_text_with_media,omitempty"` + RetryMax int `json:"retry_max,omitempty"` + RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` } // NormalizeOutboundPolicy fills zero-value fields with sensible defaults. @@ -199,11 +200,16 @@ func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]Outbou return nil, errors.New("message is required") } normalized := normalizeOutboundMessage(msg.Message) + attachments := append([]Attachment(nil), normalized.Attachments...) chunker := policy.Chunker if normalized.Format == MessageFormatMarkdown { chunker = ChunkMarkdownText } base := normalized + if shouldInlineTextWithMedia(policy, base, attachments) { + attachments[0].Caption = strings.TrimSpace(base.Text) + base.Text = "" + } base.Attachments = nil textMessages := make([]OutboundMessage, 0) shouldChunk := policy.TextChunkLimit > 0 && strings.TrimSpace(base.Text) != "" && len(base.Parts) == 0 @@ -238,7 +244,6 @@ func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]Outbou textMessages = append(textMessages, OutboundMessage{Target: msg.Target, Message: base}) } - attachments := normalized.Attachments attachmentMessages := make([]OutboundMessage, 0) if len(attachments) > 0 { media := normalized @@ -259,6 +264,24 @@ func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]Outbou return append(attachmentMessages, textMessages...), nil } +func shouldInlineTextWithMedia(policy OutboundPolicy, msg Message, attachments []Attachment) bool { + if !policy.InlineTextWithMedia { + return false + } + if strings.TrimSpace(msg.Text) == "" || len(msg.Parts) > 0 || len(attachments) == 0 { + return false + } + if strings.TrimSpace(attachments[0].Caption) != "" { + return false + } + switch attachments[0].Type { + case AttachmentImage, AttachmentGIF, AttachmentVideo, AttachmentAudio, AttachmentVoice: + return true + default: + return false + } +} + func normalizeOutboundMessage(msg Message) Message { if msg.Format == "" { if len(msg.Parts) > 0 { diff --git a/internal/channel/outbound_test.go b/internal/channel/outbound_test.go index c0675bea..4fb2e2ec 100644 --- a/internal/channel/outbound_test.go +++ b/internal/channel/outbound_test.go @@ -291,6 +291,40 @@ func TestPushFinalWithChunking_AttachmentsSeparated(t *testing.T) { } } +func TestBuildOutboundMessages_InlineTextWithMediaMovesTextToCaption(t *testing.T) { + t.Parallel() + + msgs, err := buildOutboundMessages(OutboundMessage{ + Target: "chat-1", + Message: Message{ + Text: "test.jpg from QQ", + Attachments: []Attachment{{ + Type: AttachmentImage, + URL: "https://example.com/test.jpg", + }}, + }, + }, OutboundPolicy{ + TextChunkLimit: 100, + MediaOrder: OutboundOrderTextFirst, + InlineTextWithMedia: true, + }) + if err != nil { + t.Fatalf("buildOutboundMessages failed: %v", err) + } + if len(msgs) != 1 { + t.Fatalf("expected 1 outbound message, got %d", len(msgs)) + } + if got := strings.TrimSpace(msgs[0].Message.Text); got != "" { + t.Fatalf("expected inline caption to suppress standalone text, got %q", got) + } + if len(msgs[0].Message.Attachments) != 1 { + t.Fatalf("expected 1 attachment, got %d", len(msgs[0].Message.Attachments)) + } + if got := msgs[0].Message.Attachments[0].Caption; got != "test.jpg from QQ" { + t.Fatalf("unexpected attachment caption: %q", got) + } +} + func TestPushFinalWithChunking_NonFinalPassthrough(t *testing.T) { t.Parallel() stream, rec, sent := newChunkingTestStream(t, 100) diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 6ff720b9..7832d68d 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -359,6 +359,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r return resolvedContext{}, loadErr } loaded = pruneHistoryForGateway(loaded) + loaded = dedupePersistedCurrentUserMessage(loaded, req) messages = trimMessagesByTokens(r.logger, loaded, historyBudget) r.logger.Debug("context trim result", slog.Int("loaded_messages", len(loaded)), @@ -1184,6 +1185,10 @@ type messageWithUsage struct { Message conversation.ModelMessage UsageInputTokens *int UsageOutputTokens *int + RouteID string + ExternalMessageID string + Platform string + SenderChannelID string } func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]messageWithUsage, error) { @@ -1214,11 +1219,55 @@ func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMi outputTokens = u.OutputTokens } } - result = append(result, messageWithUsage{Message: mm, UsageInputTokens: inputTokens, UsageOutputTokens: outputTokens}) + result = append(result, messageWithUsage{ + Message: mm, + UsageInputTokens: inputTokens, + UsageOutputTokens: outputTokens, + RouteID: strings.TrimSpace(m.RouteID), + ExternalMessageID: strings.TrimSpace(m.ExternalMessageID), + Platform: strings.TrimSpace(m.Platform), + SenderChannelID: strings.TrimSpace(m.SenderChannelIdentityID), + }) } return result, nil } +func dedupePersistedCurrentUserMessage(messages []messageWithUsage, req conversation.ChatRequest) []messageWithUsage { + if !req.UserMessagePersisted || len(messages) == 0 { + return messages + } + + targetRouteID := strings.TrimSpace(req.RouteID) + targetExternalID := strings.TrimSpace(req.ExternalMessageID) + targetPlatform := strings.TrimSpace(req.CurrentChannel) + targetSenderChannelID := strings.TrimSpace(req.SourceChannelIdentityID) + if targetExternalID == "" { + return messages + } + + for i := len(messages) - 1; i >= 0; i-- { + item := messages[i] + if !strings.EqualFold(strings.TrimSpace(item.Message.Role), "user") { + continue + } + if strings.TrimSpace(item.ExternalMessageID) != targetExternalID { + continue + } + if targetRouteID != "" && item.RouteID != "" && item.RouteID != targetRouteID { + continue + } + if targetPlatform != "" && item.Platform != "" && !strings.EqualFold(item.Platform, targetPlatform) { + continue + } + if targetSenderChannelID != "" && item.SenderChannelID != "" && item.SenderChannelID != targetSenderChannelID { + continue + } + return append(messages[:i], messages[i+1:]...) + } + + return messages +} + func estimateMessageTokens(msg conversation.ModelMessage) int { text := msg.TextContent() if len(text) == 0 { @@ -1349,6 +1398,11 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.Chat return err } senderChannelIdentityID, senderUserID := r.resolvePersistSenderIDs(ctx, req) + meta := buildRouteMetadata(req) + if meta == nil { + meta = map[string]any{} + } + meta["trigger_mode"] = "active_chat" _, err = r.messageService.Persist(ctx, messagepkg.PersistInput{ BotID: req.BotID, RouteID: req.RouteID, @@ -1358,7 +1412,7 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.Chat ExternalMessageID: req.ExternalMessageID, Role: "user", Content: content, - Metadata: buildRouteMetadata(req), + Metadata: meta, Assets: chatAttachmentsToAssetRefs(req.Attachments), }) return err diff --git a/internal/conversation/flow/resolver_dedupe_test.go b/internal/conversation/flow/resolver_dedupe_test.go new file mode 100644 index 00000000..e79ed6fd --- /dev/null +++ b/internal/conversation/flow/resolver_dedupe_test.go @@ -0,0 +1,45 @@ +package flow + +import ( + "testing" + + "github.com/memohai/memoh/internal/conversation" +) + +func TestDedupePersistedCurrentUserMessageRemovesCurrentInboundFromHistory(t *testing.T) { + t.Parallel() + + history := []messageWithUsage{ + { + Message: conversation.ModelMessage{ + Role: "user", + Content: conversation.NewTextContent("---\nmessage-id: qq-msg-1\nchannel: qq\n---\nhello"), + }, + RouteID: "route-1", + ExternalMessageID: "qq-msg-1", + Platform: "qq", + SenderChannelID: "channel-identity-1", + }, + { + Message: conversation.ModelMessage{ + Role: "assistant", + Content: conversation.NewTextContent("ok"), + }, + }, + } + + got := dedupePersistedCurrentUserMessage(history, conversation.ChatRequest{ + UserMessagePersisted: true, + RouteID: "route-1", + ExternalMessageID: "qq-msg-1", + CurrentChannel: "qq", + SourceChannelIdentityID: "channel-identity-1", + }) + + if len(got) != 1 { + t.Fatalf("expected 1 message after dedupe, got %d", len(got)) + } + if got[0].Message.Role != "assistant" { + t.Fatalf("unexpected remaining role: %s", got[0].Message.Role) + } +} diff --git a/internal/mcp/mcpclient/client.go b/internal/mcp/mcpclient/client.go index f8b9728b..11617189 100644 --- a/internal/mcp/mcpclient/client.go +++ b/internal/mcp/mcpclient/client.go @@ -210,7 +210,7 @@ func (c *Client) ReadRaw(ctx context.Context, path string) (io.ReadCloser, error if err != nil { return nil, mapError(err) } - return &streamReader{stream: stream}, nil + return newStreamReader(stream) } // WriteRaw writes raw bytes to a file in the container. @@ -264,15 +264,40 @@ type streamReader struct { off int } -func (r *streamReader) Read(p []byte) (int, error) { +func newStreamReader(stream pb.ContainerService_ReadRawClient) (io.ReadCloser, error) { + first, err := stream.Recv() + switch { + case errors.Is(err, io.EOF): + return io.NopCloser(bytes.NewReader(nil)), nil + case err != nil: + return nil, mapError(err) + default: + return &streamReader{stream: stream, buf: first.GetData()}, nil + } +} + +func (r *streamReader) fill() error { for r.off >= len(r.buf) { msg, err := r.stream.Recv() if err != nil { - return 0, err + if errors.Is(err, io.EOF) { + return io.EOF + } + return mapError(err) } r.buf = msg.GetData() r.off = 0 } + return nil +} + +func (r *streamReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if err := r.fill(); err != nil { + return 0, err + } n := copy(p, r.buf[r.off:]) r.off += n return n, nil diff --git a/internal/mcp/mcpclient/client_test.go b/internal/mcp/mcpclient/client_test.go new file mode 100644 index 00000000..df75c1e5 --- /dev/null +++ b/internal/mcp/mcpclient/client_test.go @@ -0,0 +1,130 @@ +package mcpclient + +import ( + "context" + "errors" + "io" + "net" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" + + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" +) + +const testBufSize = 1 << 20 + +type rawReadTestServer struct { + pb.UnimplementedContainerServiceServer + files map[string][]byte +} + +func (s *rawReadTestServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { + data, ok := s.files[req.GetPath()] + if !ok { + return status.Errorf(codes.NotFound, "open: open %s: no such file or directory", req.GetPath()) + } + if len(data) == 0 { + return nil + } + if err := stream.Send(&pb.DataChunk{Data: data[:1]}); err != nil { + return err + } + if len(data) > 1 { + if err := stream.Send(&pb.DataChunk{Data: data[1:]}); err != nil { + return err + } + } + return nil +} + +func newTestReadRawClient(t *testing.T, files map[string][]byte) *Client { + t.Helper() + + lis := bufconn.Listen(testBufSize) + srv := grpc.NewServer() + pb.RegisterContainerServiceServer(srv, &rawReadTestServer{files: files}) + + done := make(chan struct{}) + go func() { + defer close(done) + _ = srv.Serve(lis) + }() + t.Cleanup(func() { + srv.Stop() + <-done + }) + + dialer := func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + } + conn, err := grpc.NewClient("passthrough://bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("grpc.NewClient: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + + return NewClientFromConn(conn) +} + +func TestClientReadRawMissingFileReturnsNotFoundImmediately(t *testing.T) { + t.Parallel() + + client := newTestReadRawClient(t, map[string][]byte{}) + _, err := client.ReadRaw(context.Background(), "/data/media/missing.jpg") + if err == nil { + t.Fatal("expected read raw to fail for missing file") + } + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } +} + +func TestClientReadRawPreservesFirstChunk(t *testing.T) { + t.Parallel() + + client := newTestReadRawClient(t, map[string][]byte{ + "/data/media/existing.jpg": []byte("hello"), + }) + reader, err := client.ReadRaw(context.Background(), "/data/media/existing.jpg") + if err != nil { + t.Fatalf("ReadRaw returned error: %v", err) + } + defer func() { _ = reader.Close() }() + + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("read raw reader failed: %v", err) + } + if got := string(data); got != "hello" { + t.Fatalf("expected full payload, got %q", got) + } +} + +func TestClientReadRawSupportsEmptyFile(t *testing.T) { + t.Parallel() + + client := newTestReadRawClient(t, map[string][]byte{ + "/data/media/empty.txt": {}, + }) + reader, err := client.ReadRaw(context.Background(), "/data/media/empty.txt") + if err != nil { + t.Fatalf("ReadRaw returned error: %v", err) + } + defer func() { _ = reader.Close() }() + + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("read raw empty reader failed: %v", err) + } + if len(data) != 0 { + t.Fatalf("expected empty payload, got %q", string(data)) + } +} diff --git a/internal/mcp/providers/message/provider.go b/internal/mcp/providers/message/provider.go index d45857cc..d1548312 100644 --- a/internal/mcp/providers/message/provider.go +++ b/internal/mcp/providers/message/provider.go @@ -193,8 +193,15 @@ func (p *Executor) callSend(ctx context.Context, session mcpgw.ToolSessionContex // Resolve top-level attachments parameter. if rawAttachments, ok := arguments["attachments"]; ok && rawAttachments != nil { - if arr, ok := rawAttachments.([]any); ok && len(arr) > 0 { - resolved := p.resolveAttachments(ctx, botID, arr) + items := normalizeAttachmentInputs(rawAttachments) + if items == nil { + return mcpgw.BuildToolErrorResult("attachments must be a string, object, or array"), nil + } + if len(items) > 0 { + resolved := p.resolveAttachments(ctx, botID, items) + if len(resolved) == 0 { + return mcpgw.BuildToolErrorResult("attachments could not be resolved"), nil + } outboundMessage.Attachments = append(outboundMessage.Attachments, resolved...) } } @@ -352,6 +359,28 @@ func (p *Executor) resolveAttachments(ctx context.Context, botID string, items [ return result } +func normalizeAttachmentInputs(raw any) []any { + switch v := raw.(type) { + case nil: + return nil + case []any: + if v == nil { + return []any{} + } + return v + case []string: + items := make([]any, 0, len(v)) + for _, item := range v { + items = append(items, item) + } + return items + case string, map[string]any: + return []any{v} + default: + return nil + } +} + // resolveAttachmentRef resolves a single path or URL to a channel.Attachment. func (p *Executor) resolveAttachmentRef(ctx context.Context, botID, ref, attType, name string) *channel.Attachment { ref = strings.TrimSpace(ref) diff --git a/internal/mcp/providers/message/provider_test.go b/internal/mcp/providers/message/provider_test.go index 9660ac12..cbab0ebb 100644 --- a/internal/mcp/providers/message/provider_test.go +++ b/internal/mcp/providers/message/provider_test.go @@ -3,6 +3,7 @@ package message import ( "context" "errors" + "strings" "testing" "github.com/memohai/memoh/internal/channel" @@ -41,6 +42,33 @@ func (f *fakeResolver) ParseChannelType(_ string) (channel.ChannelType, error) { return f.ct, nil } +type fakeAssetResolver struct { + getAsset AssetMeta + getErr error + ingestAsset AssetMeta + ingestErr error +} + +func (f *fakeAssetResolver) GetByStorageKey(context.Context, string, string) (AssetMeta, error) { + if f.getErr != nil { + return AssetMeta{}, f.getErr + } + if strings.TrimSpace(f.getAsset.ContentHash) != "" { + return f.getAsset, nil + } + return AssetMeta{}, errors.New("not found") +} + +func (f *fakeAssetResolver) IngestContainerFile(context.Context, string, string) (AssetMeta, error) { + if f.ingestErr != nil { + return AssetMeta{}, f.ingestErr + } + if strings.TrimSpace(f.ingestAsset.ContentHash) != "" { + return f.ingestAsset, nil + } + return AssetMeta{}, errors.New("ingest disabled") +} + // --- send tests --- func TestExecutor_ListTools_NilDeps(t *testing.T) { @@ -292,6 +320,102 @@ func TestExecutor_CallTool_NoReplyTo(t *testing.T) { } } +func TestExecutor_CallTool_TopLevelAttachmentsArePreserved(t *testing.T) { + tests := []struct { + name string + attachments any + }{ + {name: "string array", attachments: []string{"https://example.com/test.jpg"}}, + {name: "single string", attachments: "https://example.com/test.jpg"}, + {name: "object", attachments: map[string]any{"url": "https://example.com/test.jpg"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("qq")} + exec := NewExecutor(nil, sender, nil, resolver, &fakeAssetResolver{}) + session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "qq"} + + result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ + "platform": "qq", + "target": "3fe2bad9-3eae-4f23-872c-b7a63662aa00", + "text": "test.jpg from QQ", + "attachments": tt.attachments, + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + if len(sender.lastReq.Message.Attachments) != 1 { + t.Fatalf("expected 1 attachment, got %d", len(sender.lastReq.Message.Attachments)) + } + att := sender.lastReq.Message.Attachments[0] + if att.URL != "https://example.com/test.jpg" { + t.Fatalf("unexpected attachment url: %q", att.URL) + } + if att.Type != channel.AttachmentImage { + t.Fatalf("unexpected attachment type: %q", att.Type) + } + }) + } +} + +func TestExecutor_CallTool_AllowsEmptyTopLevelAttachmentsArray(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("qq")} + exec := NewExecutor(nil, sender, nil, resolver, &fakeAssetResolver{}) + session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "qq"} + + result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ + "platform": "qq", + "target": "3fe2bad9-3eae-4f23-872c-b7a63662aa00", + "text": "hello", + "attachments": []any{}, + }) + if err != nil { + t.Fatal(err) + } + if err := mcpgw.PayloadError(result); err != nil { + t.Fatal(err) + } + if len(sender.lastReq.Message.Attachments) != 0 { + t.Fatalf("expected no attachments, got %d", len(sender.lastReq.Message.Attachments)) + } +} + +func TestExecutor_CallTool_DataAttachmentsFailWhenIngestFails(t *testing.T) { + sender := &fakeSender{} + resolver := &fakeResolver{ct: channel.ChannelType("qq")} + exec := NewExecutor(nil, sender, nil, resolver, &fakeAssetResolver{ingestErr: errors.New("ingest disabled")}) + session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "qq"} + + result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ + "platform": "qq", + "target": "3fe2bad9-3eae-4f23-872c-b7a63662aa00", + "text": "test.jpg from QQ", + "attachments": []string{"/data/test.jpg"}, + }) + if err != nil { + t.Fatal(err) + } + if isErr, _ := result["isError"].(bool); !isErr { + t.Fatal("expected attachment resolution error") + } + payloadMsg := "" + if content, ok := result["content"].([]map[string]any); ok && len(content) > 0 { + payloadMsg, _ = content[0]["text"].(string) + } + if !strings.Contains(payloadMsg, "attachments could not be resolved") { + t.Fatalf("unexpected error: %v", payloadMsg) + } + if len(sender.lastReq.Message.Attachments) != 0 { + t.Fatalf("expected no outbound attachments, got %d", len(sender.lastReq.Message.Attachments)) + } +} + // --- react tests --- func TestExecutor_React_NilReactor(t *testing.T) { diff --git a/internal/mcp/providers/skill/provider.go b/internal/mcp/providers/skill/provider.go index 932022b6..d4439cb6 100644 --- a/internal/mcp/providers/skill/provider.go +++ b/internal/mcp/providers/skill/provider.go @@ -24,7 +24,7 @@ func NewExecutor(log *slog.Logger) *Executor { } } -func (e *Executor) ListTools(_ context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (*Executor) ListTools(_ context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if session.IsSubagent { return []mcpgw.ToolDescriptor{}, nil } @@ -50,7 +50,7 @@ func (e *Executor) ListTools(_ context.Context, session mcpgw.ToolSessionContext }, nil } -func (e *Executor) CallTool(_ context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { +func (*Executor) CallTool(_ context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { if toolName != toolUseSkill { return nil, mcpgw.ErrToolNotFound } diff --git a/internal/mcp/providers/subagent/provider.go b/internal/mcp/providers/subagent/provider.go index 39a2e0b7..19492e54 100644 --- a/internal/mcp/providers/subagent/provider.go +++ b/internal/mcp/providers/subagent/provider.go @@ -9,6 +9,7 @@ import ( "io" "log/slog" "net/http" + "slices" "strings" "time" @@ -197,7 +198,8 @@ func (e *Executor) callQuery(ctx context.Context, session mcpgw.ToolSessionConte return mcpgw.BuildToolErrorResult(fmt.Sprintf("subagent query failed: %v", err)), nil } - updatedMessages := append(target.Messages, gwResp.Messages...) + updatedMessages := slices.Clone(target.Messages) + updatedMessages = append(updatedMessages, gwResp.Messages...) usage := mergeUsage(target.Usage, gwResp.Usage) if _, err := e.service.UpdateContext(ctx, target.ID, subagentsvc.UpdateContextRequest{ Messages: updatedMessages, diff --git a/internal/mcp/providers/webfetch/provider.go b/internal/mcp/providers/webfetch/provider.go index ecf5a8b3..b2bff4db 100644 --- a/internal/mcp/providers/webfetch/provider.go +++ b/internal/mcp/providers/webfetch/provider.go @@ -39,7 +39,7 @@ func NewExecutor(log *slog.Logger) *Executor { } } -func (e *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (*Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { return []mcpgw.ToolDescriptor{ { Name: toolWebFetch, @@ -134,7 +134,7 @@ func detectFormat(contentType string) string { } } -func (e *Executor) processJSON(fetchedURL, contentType string, body []byte) (map[string]any, error) { +func (*Executor) processJSON(fetchedURL, contentType string, body []byte) (map[string]any, error) { var data any if err := json.Unmarshal(body, &data); err != nil { return mcpgw.BuildToolErrorResult("Failed to parse JSON"), nil @@ -148,7 +148,7 @@ func (e *Executor) processJSON(fetchedURL, contentType string, body []byte) (map }), nil } -func (e *Executor) processXML(fetchedURL, contentType string, body []byte) (map[string]any, error) { +func (*Executor) processXML(fetchedURL, contentType string, body []byte) (map[string]any, error) { content := string(body) if len(content) > maxTextContent { content = content[:maxTextContent] @@ -202,7 +202,7 @@ func (e *Executor) processHTML(fetchedURL, contentType string, body []byte) (map }), nil } -func (e *Executor) processText(fetchedURL, contentType string, body []byte) (map[string]any, error) { +func (*Executor) processText(fetchedURL, contentType string, body []byte) (map[string]any, error) { content := string(body) length := len(content) if length > maxTextContent { diff --git a/packages/agent/src/agent.test.ts b/packages/agent/src/agent.test.ts new file mode 100644 index 00000000..475c7ce9 --- /dev/null +++ b/packages/agent/src/agent.test.ts @@ -0,0 +1,88 @@ +import { describe, expect, it } from 'vitest' +import { createImagePartFromAttachment } from './utils/image-parts' + +describe('createImagePartFromAttachment', () => { + it('converts inline data URLs to binary image parts', () => { + const part = createImagePartFromAttachment({ + type: 'image', + transport: 'inline_data_url', + payload: 'data:image/png;base64,AQID', + }) + + expect(part?.type).toBe('image') + expect(part?.image).toBeInstanceOf(Uint8Array) + expect(Array.from(part?.image as Uint8Array)).toEqual([1, 2, 3]) + expect(part?.mediaType).toBe('image/png') + }) + + it('keeps public URLs as URL objects', () => { + const part = createImagePartFromAttachment({ + type: 'image', + transport: 'public_url', + payload: 'https://example.com/demo.png', + }) + + expect(part?.image).toBeInstanceOf(URL) + expect(String(part?.image)).toBe('https://example.com/demo.png') + }) + + it('falls back to string payloads for malformed public URLs', () => { + const part = createImagePartFromAttachment({ + type: 'image', + transport: 'public_url', + payload: 'https://', + mime: 'image/png', + }) + + expect(part?.image).toBe('https://') + expect(part?.mediaType).toBe('image/png') + }) + + it('keeps inline payload strings when they are not data URLs', () => { + const part = createImagePartFromAttachment({ + type: 'image', + transport: 'inline_data_url', + payload: 'AQID', + mime: 'image/png', + }) + + expect(part?.image).toBe('AQID') + expect(part?.mediaType).toBe('image/png') + }) + + it('falls back to string payloads for malformed non-base64 data URLs', () => { + const payload = 'data:image/png,a%ZZ' + const part = createImagePartFromAttachment({ + type: 'image', + transport: 'inline_data_url', + payload, + mime: 'image/png', + }) + + expect(part?.image).toBe(payload) + expect(part?.mediaType).toBe('image/png') + }) + + it('falls back to string payloads for malformed base64 data URLs', () => { + const payload = 'data:image/png;base64,%%%' + const part = createImagePartFromAttachment({ + type: 'image', + transport: 'inline_data_url', + payload, + mime: 'image/png', + }) + + expect(part?.image).toBe(payload) + expect(part?.mediaType).toBe('image/png') + }) + + it('skips tool file references', () => { + const part = createImagePartFromAttachment({ + type: 'image', + transport: 'tool_file_ref', + payload: '/data/media/demo.png', + }) + + expect(part).toBeNull() + }) +}) diff --git a/packages/agent/src/agent.ts b/packages/agent/src/agent.ts index 75430f05..7fdcb2f8 100644 --- a/packages/agent/src/agent.ts +++ b/packages/agent/src/agent.ts @@ -29,6 +29,7 @@ import { dedupeAttachments, AttachmentsStreamExtractor, } from './utils/attachments' +import { createImagePartFromAttachment } from './utils/image-parts' import type { GatewayInputAttachment } from './types/attachment' import { getMCPTools } from './tools/mcp' import { buildIdentityHeaders } from './utils/headers' @@ -80,12 +81,8 @@ const buildStepUsages = ( export const buildNativeImageParts = (attachments: GatewayInputAttachment[]): ImagePart[] => { return attachments - .filter((attachment) => - attachment.type === 'image' && - (attachment.transport === 'inline_data_url' || attachment.transport === 'public_url') && - Boolean(attachment.payload), - ) - .map((attachment): ImagePart => ({ type: 'image', image: attachment.payload })) + .map((attachment) => createImagePartFromAttachment(attachment)) + .filter((attachment): attachment is ImagePart => attachment != null) } export const createAgent = ( diff --git a/packages/agent/src/utils/image-parts.ts b/packages/agent/src/utils/image-parts.ts new file mode 100644 index 00000000..db9c4ae9 --- /dev/null +++ b/packages/agent/src/utils/image-parts.ts @@ -0,0 +1,144 @@ +import type { ImagePart } from 'ai' +import type { GatewayInputAttachment } from '../types/attachment' + +type NativeImageAttachment = GatewayInputAttachment & { + type: 'image' + transport: 'inline_data_url' | 'public_url' +} + +type ImagePartPayload = string | Uint8Array | URL +const strictBase64Pattern = /^[A-Za-z0-9+/]*={0,2}$/ + +const normalizeMediaType = (value?: string): string | undefined => { + const mediaType = typeof value === 'string' ? value.trim() : '' + return mediaType || undefined +} + +const createImagePart = (image: ImagePartPayload, mediaType?: string): ImagePart => { + const normalizedMediaType = normalizeMediaType(mediaType) + if (normalizedMediaType == null) { + return { type: 'image', image } + } + return { type: 'image', image, mediaType: normalizedMediaType } +} + +const decodeBase64Strict = (value: string): Buffer | null => { + const normalized = value.replace(/\s+/g, '') + if (normalized === '' || !strictBase64Pattern.test(normalized)) { + return null + } + + const firstPadding = normalized.indexOf('=') + if (firstPadding >= 0) { + if (/[A-Za-z0-9+/]/.test(normalized.slice(firstPadding))) { + return null + } + if (normalized.length-firstPadding > 2 || normalized.length % 4 !== 0) { + return null + } + } + else if (normalized.length % 4 === 1) { + return null + } + + const padded = firstPadding >= 0 + ? normalized + : normalized + '='.repeat((4 - (normalized.length % 4)) % 4) + + const decoded = Buffer.from(padded, 'base64') + const canonical = decoded.toString('base64').replace(/=+$/g, '') + const input = normalized.replace(/=+$/g, '') + if (canonical !== input) { + return null + } + + return decoded +} + +const parseDataUrl = (payload: string): { bytes: Uint8Array; mediaType?: string } | null => { + const trimmed = payload.trim() + if (!trimmed.toLowerCase().startsWith('data:')) { + return null + } + + const commaIndex = trimmed.indexOf(',') + if (commaIndex < 0) { + return null + } + + const header = trimmed.slice(5, commaIndex) + const body = trimmed.slice(commaIndex + 1) + const segments = header.split(';').map((segment) => segment.trim()).filter(Boolean) + const mediaType = normalizeMediaType(segments.find((segment) => segment.includes('/'))) + const isBase64 = segments.some((segment) => segment.toLowerCase() === 'base64') + let buffer: Buffer + if (isBase64) { + const decoded = decodeBase64Strict(body) + if (decoded == null) { + return null + } + buffer = decoded + } + else { + try { + buffer = Buffer.from(decodeURIComponent(body), 'utf8') + } + catch { + return null + } + } + + return { + bytes: new Uint8Array(buffer), + mediaType, + } +} + +const isNativeImageAttachment = ( + attachment: GatewayInputAttachment, +): attachment is NativeImageAttachment => { + if (attachment.type !== 'image') { + return false + } + if (attachment.transport !== 'inline_data_url' && attachment.transport !== 'public_url') { + return false + } + return typeof attachment.payload === 'string' && attachment.payload.trim() !== '' +} + +const createInlineDataImagePart = (payload: string, mediaType?: string): ImagePart => { + const parsed = parseDataUrl(payload) + if (parsed != null) { + return createImagePart(parsed.bytes, mediaType ?? parsed.mediaType) + } + return createImagePart(payload, mediaType) +} + +const createPublicURLImagePart = (payload: string, mediaType?: string): ImagePart => { + try { + return createImagePart(new URL(payload), mediaType) + } + catch { + return createImagePart(payload, mediaType) + } +} + +export const createBinaryImagePart = (bytes: Uint8Array, mediaType?: string): ImagePart => { + return createImagePart(bytes, mediaType) +} + +export const createImagePartFromAttachment = ( + attachment: GatewayInputAttachment, +): ImagePart | null => { + if (!isNativeImageAttachment(attachment)) { + return null + } + + const payload = attachment.payload.trim() + switch (attachment.transport) { + case 'public_url': + return createPublicURLImagePart(payload, attachment.mime) + case 'inline_data_url': + return createInlineDataImagePart(payload, attachment.mime) + } +} diff --git a/packages/agent/src/utils/read-media-injector.test.ts b/packages/agent/src/utils/read-media-injector.test.ts index f053565d..740eda80 100644 --- a/packages/agent/src/utils/read-media-injector.test.ts +++ b/packages/agent/src/utils/read-media-injector.test.ts @@ -10,6 +10,11 @@ const baseModelConfig: ModelConfig = { input: [ModelInput.Image], } +const createToolOptions = (toolCallId: string) => ({ + toolCallId, + messages: [], +}) + describe('read_media runtime', () => { it('caches image and injects it into messages', async () => { const fs = { @@ -23,10 +28,10 @@ describe('read_media runtime', () => { fs, systemPrompt: 'sys', }) - const readMedia = tools.read_media - const output = await readMedia.execute( + const executeReadMedia = tools.read_media.execute! + const output = await executeReadMedia( { path: '/data/media/a.png' }, - { toolCallId: 'call-1' }, + createToolOptions('call-1'), ) expect((output as { ok?: boolean }).ok).toBe(true) const prepared = await prepareStep({ @@ -38,9 +43,11 @@ describe('read_media runtime', () => { }) const injected = prepared.messages?.[1] expect(injected?.role).toBe('user') - const content = injected?.content as Array<{ type?: string; image?: string }> + const content = injected?.content as Array<{ type?: string; image?: Uint8Array; mediaType?: string }> expect(content?.[0]?.type).toBe('image') - expect(content?.[0]?.image?.startsWith('data:image/png;base64,')).toBe(true) + expect(content?.[0]?.image).toBeInstanceOf(Uint8Array) + expect(Array.from(content?.[0]?.image ?? [])).toEqual([1, 2, 3]) + expect(content?.[0]?.mediaType).toBe('image/png') }) it('returns error result on download failure', async () => { @@ -54,10 +61,10 @@ describe('read_media runtime', () => { fs, systemPrompt: 'sys', }) - const readMedia = tools.read_media - const output = await readMedia.execute( + const executeReadMedia = tools.read_media.execute! + const output = await executeReadMedia( { path: '/data/media/a.png' }, - { toolCallId: 'call-2' }, + createToolOptions('call-2'), ) expect((output as { isError?: boolean }).isError).toBe(true) const prepared = await prepareStep({ @@ -84,14 +91,14 @@ describe('read_media runtime', () => { fs, systemPrompt: 'sys', }) - const readMedia = tools.read_media - const first = readMedia.execute( + const executeReadMedia = tools.read_media.execute! + const first = executeReadMedia( { path: '/data/media/a.png' }, - { toolCallId: 'call-1' }, + createToolOptions('call-1'), ) - const second = readMedia.execute( + const second = executeReadMedia( { path: '/data/media/b.png' }, - { toolCallId: 'call-2' }, + createToolOptions('call-2'), ) await Promise.all([first, second]) const prepared = await prepareStep({ @@ -102,8 +109,10 @@ describe('read_media runtime', () => { experimental_context: undefined, }) const injected = prepared.messages?.[1] - const content = injected?.content as Array<{ type?: string; image?: string }> - expect(content?.[0]?.image?.includes('AQ==')).toBe(true) - expect(content?.[1]?.image?.includes('Ag==')).toBe(true) + const content = injected?.content as Array<{ type?: string; image?: Uint8Array; mediaType?: string }> + expect(Array.from(content?.[0]?.image ?? [])).toEqual([1]) + expect(Array.from(content?.[1]?.image ?? [])).toEqual([2]) + expect(content?.[0]?.mediaType).toBe('image/png') + expect(content?.[1]?.mediaType).toBe('image/png') }) }) diff --git a/packages/agent/src/utils/read-media-injector.ts b/packages/agent/src/utils/read-media-injector.ts index da227712..cf9a9446 100644 --- a/packages/agent/src/utils/read-media-injector.ts +++ b/packages/agent/src/utils/read-media-injector.ts @@ -1,6 +1,7 @@ import { ImagePart, PrepareStepFunction, ToolSet, UserModelMessage, tool } from 'ai' import { z } from 'zod' import { ModelConfig, ModelInput, hasInputModality } from '../types/model' +import { createBinaryImagePart } from './image-parts' const READ_MEDIA_TOOL_NAME = 'read_media' @@ -8,10 +9,6 @@ const isImageMime = (mime: string): boolean => { return mime.trim().toLowerCase().startsWith('image/') } -const toImagePart = (payload: string): ImagePart => { - return { type: 'image', image: payload } as ImagePart -} - type ReadMediaFS = { download: (path: string) => Promise } @@ -22,20 +19,19 @@ const buildReadMediaToolError = (message: string) => ({ structuredContent: { ok: false, error: message }, }) -const loadImageAsDataUrl = async ( +const loadImageBytes = async ( fs: ReadMediaFS, path: string, -): Promise<{ ok: true; dataUrl: string; mime: string } | { ok: false; error: string }> => { +): Promise<{ ok: true; bytes: Uint8Array; mime: string } | { ok: false; error: string }> => { try { const response = await fs.download(path) - const arrayBuffer = await response.arrayBuffer() - const base64 = Buffer.from(arrayBuffer).toString('base64') + const bytes = new Uint8Array(await response.arrayBuffer()) const header = response.headers.get('content-type') ?? '' const mime = header.split(';')[0]?.trim() ?? '' if (!mime || !isImageMime(mime)) { return { ok: false, error: 'read_media only supports image files' } } - return { ok: true, dataUrl: `data:${mime};base64,${base64}`, mime } + return { ok: true, bytes, mime } } catch (error) { console.error(error) const message = error instanceof Error ? error.message : String(error) @@ -77,11 +73,11 @@ export const createPrepareStepWithReadMedia = (params: { cachedImages.set(toolCallId, null) callOrder.push(toolCallId) } - const loaded = await loadImageAsDataUrl(params.fs, trimmedPath) + const loaded = await loadImageBytes(params.fs, trimmedPath) if (!loaded.ok) { return buildReadMediaToolError(loaded.error) } - cachedImages.set(toolCallId, toImagePart(loaded.dataUrl) as ImagePart) + cachedImages.set(toolCallId, createBinaryImagePart(loaded.bytes, loaded.mime)) return { ok: true, path: trimmedPath, mime: loaded.mime } }, })