From bc47655309b1378553e437f60a218daa69388f65 Mon Sep 17 00:00:00 2001 From: BBQ Date: Tue, 10 Mar 2026 17:43:47 +0800 Subject: [PATCH] fix(wecom): align adapter with channel stream behavior Migrate the imported WeCom adapter to current channel interfaces and stabilize stream delivery by preventing heartbeat/reply ACK timeout regressions and post-final overwrite updates. --- cmd/agent/main.go | 2 + cmd/memoh/serve.go | 2 + .../wecom/adapter_integration_test.go | 157 +++++++ .../channel/adapters/wecom/callback_cache.go | 77 ++++ .../adapters/wecom/callback_cache_test.go | 29 ++ internal/channel/adapters/wecom/config.go | 210 +++++++++ .../channel/adapters/wecom/config_test.go | 27 ++ internal/channel/adapters/wecom/crypto.go | 52 +++ .../channel/adapters/wecom/crypto_test.go | 50 ++ .../channel/adapters/wecom/http_client.go | 167 +++++++ .../adapters/wecom/http_client_test.go | 36 ++ internal/channel/adapters/wecom/inbound.go | 338 ++++++++++++++ .../channel/adapters/wecom/inbound_test.go | 103 +++++ internal/channel/adapters/wecom/outbound.go | 309 +++++++++++++ .../channel/adapters/wecom/outbound_test.go | 239 ++++++++++ internal/channel/adapters/wecom/protocol.go | 256 +++++++++++ .../channel/adapters/wecom/protocol_test.go | 29 ++ internal/channel/adapters/wecom/wecom.go | 434 ++++++++++++++++++ internal/channel/adapters/wecom/wecom_test.go | 52 +++ internal/channel/adapters/wecom/ws_client.go | 396 ++++++++++++++++ .../channel/adapters/wecom/ws_client_test.go | 178 +++++++ 21 files changed, 3143 insertions(+) create mode 100644 internal/channel/adapters/wecom/adapter_integration_test.go create mode 100644 internal/channel/adapters/wecom/callback_cache.go create mode 100644 internal/channel/adapters/wecom/callback_cache_test.go create mode 100644 internal/channel/adapters/wecom/config.go create mode 100644 internal/channel/adapters/wecom/config_test.go create mode 100644 internal/channel/adapters/wecom/crypto.go create mode 100644 internal/channel/adapters/wecom/crypto_test.go create mode 100644 internal/channel/adapters/wecom/http_client.go create mode 100644 internal/channel/adapters/wecom/http_client_test.go create mode 100644 internal/channel/adapters/wecom/inbound.go create mode 100644 internal/channel/adapters/wecom/inbound_test.go create mode 100644 internal/channel/adapters/wecom/outbound.go create mode 100644 internal/channel/adapters/wecom/outbound_test.go create mode 100644 internal/channel/adapters/wecom/protocol.go create mode 100644 internal/channel/adapters/wecom/protocol_test.go create mode 100644 internal/channel/adapters/wecom/wecom.go create mode 100644 internal/channel/adapters/wecom/wecom_test.go create mode 100644 internal/channel/adapters/wecom/ws_client.go create mode 100644 internal/channel/adapters/wecom/ws_client_test.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index dbabbcdb..968cc290 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -30,6 +30,7 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/qq" "github.com/memohai/memoh/internal/channel/adapters/telegram" + "github.com/memohai/memoh/internal/channel/adapters/wecom" "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/channel/inbound" "github.com/memohai/memoh/internal/channel/route" @@ -402,6 +403,7 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService feishuAdapter := feishu.NewFeishuAdapter(log) feishuAdapter.SetAssetOpener(mediaService) registry.MustRegister(feishuAdapter) + registry.MustRegister(wecom.NewWeComAdapter(log)) registry.MustRegister(local.NewCLIAdapter(hub)) registry.MustRegister(local.NewWebAdapter(hub)) return registry diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 28e3e08b..abecf949 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -30,6 +30,7 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/feishu" "github.com/memohai/memoh/internal/channel/adapters/local" "github.com/memohai/memoh/internal/channel/adapters/telegram" + "github.com/memohai/memoh/internal/channel/adapters/wecom" "github.com/memohai/memoh/internal/channel/identities" "github.com/memohai/memoh/internal/channel/inbound" "github.com/memohai/memoh/internal/channel/route" @@ -285,6 +286,7 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService feishuAdapter := feishu.NewFeishuAdapter(log) feishuAdapter.SetAssetOpener(mediaService) registry.MustRegister(feishuAdapter) + registry.MustRegister(wecom.NewWeComAdapter(log)) registry.MustRegister(local.NewCLIAdapter(hub)) registry.MustRegister(local.NewWebAdapter(hub)) return registry diff --git a/internal/channel/adapters/wecom/adapter_integration_test.go b/internal/channel/adapters/wecom/adapter_integration_test.go new file mode 100644 index 00000000..d90f058e --- /dev/null +++ b/internal/channel/adapters/wecom/adapter_integration_test.go @@ -0,0 +1,157 @@ +package wecom + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/memohai/memoh/internal/channel" +) + +func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) { + t.Parallel() + + upgrader := websocket.Upgrader{} + receivedRespond := make(chan WSFrame, 1) + serverErr := make(chan error, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + select { + case serverErr <- err: + default: + } + return + } + defer conn.Close() + + var subscribeFrame WSFrame + if err := conn.ReadJSON(&subscribeFrame); err != nil { + select { + case serverErr <- err: + default: + } + return + } + if err := conn.WriteJSON(WSFrame{ + Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID}, + ErrCode: 0, + }); err != nil { + select { + case serverErr <- err: + default: + } + return + } + + body, _ := json.Marshal(MessageCallbackBody{ + MsgID: "msg_1", + ChatID: "chat_1", + ChatType: "group", + CreateTime: time.Now().UnixMilli(), + From: CallbackFrom{UserID: "u1"}, + MsgType: "text", + ResponseURL: "https://example.com/resp", + Text: &MessageText{Content: "hello"}, + }) + if err := conn.WriteJSON(WSFrame{ + Cmd: WSCmdMsgCallback, + Headers: WSHeaders{ReqID: "callback_req_id"}, + Body: body, + }); err != nil { + select { + case serverErr <- err: + default: + } + return + } + + var respondFrame WSFrame + if err := conn.ReadJSON(&respondFrame); err != nil { + select { + case serverErr <- err: + default: + } + return + } + select { + case receivedRespond <- respondFrame: + default: + } + _ = conn.WriteJSON(WSFrame{ + Headers: WSHeaders{ReqID: respondFrame.Headers.ReqID}, + ErrCode: 0, + }) + })) + defer server.Close() + + adapter := NewWeComAdapter(nil) + cfg := channel.ChannelConfig{ + Credentials: map[string]any{ + "botId": "bot", + "secret": "sec", + "wsUrl": "ws" + strings.TrimPrefix(server.URL, "http"), + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + inboundCh := make(chan channel.InboundMessage, 1) + conn, err := adapter.Connect(ctx, cfg, func(_ context.Context, _ channel.ChannelConfig, msg channel.InboundMessage) error { + select { + case inboundCh <- msg: + default: + } + return nil + }) + if err != nil { + t.Fatalf("connect error: %v", err) + } + defer conn.Stop(context.Background()) + + select { + case inbound := <-inboundCh: + if inbound.Message.ID != "msg_1" { + t.Fatalf("unexpected inbound message id: %s", inbound.Message.ID) + } + case err := <-serverErr: + t.Fatalf("server error: %v", err) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting inbound callback") + } + + err = adapter.Send(context.Background(), cfg, channel.OutboundMessage{ + Target: "chat_id:chat_1", + Message: channel.Message{ + Format: channel.MessageFormatPlain, + Text: "reply content", + Reply: &channel.ReplyRef{ + MessageID: "msg_1", + }, + }, + }) + if err != nil { + t.Fatalf("send error: %v", err) + } + + select { + case frame := <-receivedRespond: + if frame.Cmd != WSCmdRespond { + t.Fatalf("expected cmd=%s got=%s", WSCmdRespond, frame.Cmd) + } + if frame.Headers.ReqID != "callback_req_id" { + t.Fatalf("expected req_id callback_req_id got=%s", frame.Headers.ReqID) + } + case err := <-serverErr: + t.Fatalf("server error: %v", err) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting respond frame") + } +} + diff --git a/internal/channel/adapters/wecom/callback_cache.go b/internal/channel/adapters/wecom/callback_cache.go new file mode 100644 index 00000000..379d5049 --- /dev/null +++ b/internal/channel/adapters/wecom/callback_cache.go @@ -0,0 +1,77 @@ +package wecom + +import ( + "strings" + "sync" + "time" +) + +type callbackContext struct { + ReqID string + ResponseURL string + ChatID string + UserID string + CreatedAt time.Time +} + +type callbackContextCache struct { + mu sync.RWMutex + items map[string]callbackContext + ttl time.Duration +} + +func newCallbackContextCache(ttl time.Duration) *callbackContextCache { + if ttl <= 0 { + ttl = 24 * time.Hour + } + return &callbackContextCache{ + items: make(map[string]callbackContext), + ttl: ttl, + } +} + +func (c *callbackContextCache) Put(messageID string, ctx callbackContext) { + key := strings.TrimSpace(messageID) + if key == "" { + return + } + if ctx.CreatedAt.IsZero() { + ctx.CreatedAt = time.Now().UTC() + } + c.mu.Lock() + c.items[key] = ctx + c.gcLocked() + c.mu.Unlock() +} + +func (c *callbackContextCache) Get(messageID string) (callbackContext, bool) { + key := strings.TrimSpace(messageID) + if key == "" { + return callbackContext{}, false + } + c.mu.RLock() + item, ok := c.items[key] + c.mu.RUnlock() + if !ok { + return callbackContext{}, false + } + if time.Since(item.CreatedAt) > c.ttl { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() + return callbackContext{}, false + } + return item, true +} + +func (c *callbackContextCache) gcLocked() { + if len(c.items) < 512 { + return + } + now := time.Now().UTC() + for key, item := range c.items { + if now.Sub(item.CreatedAt) > c.ttl { + delete(c.items, key) + } + } +} diff --git a/internal/channel/adapters/wecom/callback_cache_test.go b/internal/channel/adapters/wecom/callback_cache_test.go new file mode 100644 index 00000000..4278cd82 --- /dev/null +++ b/internal/channel/adapters/wecom/callback_cache_test.go @@ -0,0 +1,29 @@ +package wecom + +import ( + "testing" + "time" +) + +func TestCallbackContextCache_PutGet(t *testing.T) { + cache := newCallbackContextCache(1 * time.Hour) + cache.Put("m1", callbackContext{ReqID: "r1"}) + got, ok := cache.Get("m1") + if !ok { + t.Fatal("expected cache hit") + } + if got.ReqID != "r1" { + t.Fatalf("unexpected req id: %q", got.ReqID) + } +} + +func TestCallbackContextCache_Expires(t *testing.T) { + cache := newCallbackContextCache(1 * time.Second) + cache.Put("m1", callbackContext{ + ReqID: "r1", + CreatedAt: time.Now().Add(-2 * time.Second), + }) + if _, ok := cache.Get("m1"); ok { + t.Fatal("expected cache miss due to expiry") + } +} diff --git a/internal/channel/adapters/wecom/config.go b/internal/channel/adapters/wecom/config.go new file mode 100644 index 00000000..4c595e60 --- /dev/null +++ b/internal/channel/adapters/wecom/config.go @@ -0,0 +1,210 @@ +package wecom + +import ( + "fmt" + "strconv" + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +type Config struct { + BotID string + Secret string + WSURL string + HeartbeatSeconds int + AckTimeoutSeconds int + WriteTimeoutSeconds int + ReadTimeoutSeconds int +} + +type UserConfig struct { + ChatID string + UserID string +} + +func normalizeConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseConfig(raw) + if err != nil { + return nil, err + } + out := map[string]any{ + "botId": cfg.BotID, + "secret": cfg.Secret, + } + if cfg.WSURL != "" { + out["wsUrl"] = cfg.WSURL + } + if cfg.HeartbeatSeconds > 0 { + out["heartbeatSeconds"] = cfg.HeartbeatSeconds + } + if cfg.AckTimeoutSeconds > 0 { + out["ackTimeoutSeconds"] = cfg.AckTimeoutSeconds + } + if cfg.WriteTimeoutSeconds > 0 { + out["writeTimeoutSeconds"] = cfg.WriteTimeoutSeconds + } + if cfg.ReadTimeoutSeconds > 0 { + out["readTimeoutSeconds"] = cfg.ReadTimeoutSeconds + } + return out, nil +} + +func normalizeUserConfig(raw map[string]any) (map[string]any, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return nil, err + } + out := map[string]any{} + if cfg.ChatID != "" { + out["chat_id"] = cfg.ChatID + } + if cfg.UserID != "" { + out["user_id"] = cfg.UserID + } + return out, nil +} + +func parseConfig(raw map[string]any) (Config, error) { + cfg := Config{ + BotID: strings.TrimSpace(channel.ReadString(raw, "botId", "bot_id")), + Secret: strings.TrimSpace(channel.ReadString(raw, "secret")), + WSURL: strings.TrimSpace(channel.ReadString(raw, "wsUrl", "ws_url")), + } + if value, ok := readInt(raw, "heartbeatSeconds", "heartbeat_seconds"); ok { + cfg.HeartbeatSeconds = value + } + if value, ok := readInt(raw, "ackTimeoutSeconds", "ack_timeout_seconds"); ok { + cfg.AckTimeoutSeconds = value + } + if value, ok := readInt(raw, "writeTimeoutSeconds", "write_timeout_seconds"); ok { + cfg.WriteTimeoutSeconds = value + } + if value, ok := readInt(raw, "readTimeoutSeconds", "read_timeout_seconds"); ok { + cfg.ReadTimeoutSeconds = value + } + if cfg.BotID == "" || cfg.Secret == "" { + return Config{}, fmt.Errorf("wecom botId and secret are required") + } + return cfg, nil +} + +func parseUserConfig(raw map[string]any) (UserConfig, error) { + cfg := UserConfig{ + ChatID: strings.TrimSpace(channel.ReadString(raw, "chatId", "chat_id")), + UserID: strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")), + } + if cfg.ChatID == "" && cfg.UserID == "" { + return UserConfig{}, fmt.Errorf("wecom user config requires chat_id or user_id") + } + return cfg, nil +} + +func resolveTarget(raw map[string]any) (string, error) { + cfg, err := parseUserConfig(raw) + if err != nil { + return "", err + } + if cfg.ChatID != "" { + return "chat_id:" + cfg.ChatID, nil + } + return "user_id:" + cfg.UserID, nil +} + +func normalizeTarget(raw string) string { + kind, id, ok := parseTarget(raw) + if !ok { + return "" + } + return kind + ":" + id +} + +func parseTarget(raw string) (kind string, id string, ok bool) { + v := strings.TrimSpace(raw) + if v == "" { + return "", "", false + } + v = strings.TrimPrefix(v, "wecom:") + v = strings.TrimPrefix(v, "workwx:") + v = strings.TrimSpace(v) + lv := strings.ToLower(v) + switch { + case strings.HasPrefix(lv, "chat_id:"): + id = strings.TrimSpace(v[len("chat_id:"):]) + return "chat_id", id, id != "" + case strings.HasPrefix(lv, "chat:"): + id = strings.TrimSpace(v[len("chat:"):]) + return "chat_id", id, id != "" + case strings.HasPrefix(lv, "group:"): + id = strings.TrimSpace(v[len("group:"):]) + return "chat_id", id, id != "" + case strings.HasPrefix(lv, "user_id:"): + id = strings.TrimSpace(v[len("user_id:"):]) + return "user_id", id, id != "" + case strings.HasPrefix(lv, "user:"): + id = strings.TrimSpace(v[len("user:"):]) + return "user_id", id, id != "" + default: + return "chat_id", v, true + } +} + +func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { + cfg, err := parseUserConfig(raw) + if err != nil { + return false + } + if value := strings.TrimSpace(criteria.Attribute("chat_id")); value != "" && value == cfg.ChatID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { + return true + } + if criteria.SubjectID != "" && (criteria.SubjectID == cfg.ChatID || criteria.SubjectID == cfg.UserID) { + return true + } + return false +} + +func buildUserConfig(identity channel.Identity) map[string]any { + out := map[string]any{} + if v := strings.TrimSpace(identity.Attribute("chat_id")); v != "" { + out["chat_id"] = v + } + if v := strings.TrimSpace(identity.Attribute("user_id")); v != "" { + out["user_id"] = v + } + return out +} + +func readInt(raw map[string]any, keys ...string) (int, bool) { + for _, key := range keys { + value, ok := raw[key] + if !ok { + continue + } + switch v := value.(type) { + case int: + return v, true + case int32: + return int(v), true + case int64: + return int(v), true + case float64: + return int(v), true + case float32: + return int(v), true + case string: + trimmed := strings.TrimSpace(v) + if trimmed == "" { + continue + } + parsed, err := strconv.Atoi(trimmed) + if err != nil { + continue + } + return parsed, true + } + } + return 0, false +} diff --git a/internal/channel/adapters/wecom/config_test.go b/internal/channel/adapters/wecom/config_test.go new file mode 100644 index 00000000..ea89aa8c --- /dev/null +++ b/internal/channel/adapters/wecom/config_test.go @@ -0,0 +1,27 @@ +package wecom + +import "testing" + +func TestParseConfig(t *testing.T) { + cfg, err := parseConfig(map[string]any{ + "botId": "bot-1", + "secret": "sec-1", + }) + if err != nil { + t.Fatalf("parseConfig error = %v", err) + } + if cfg.BotID != "bot-1" || cfg.Secret != "sec-1" { + t.Fatalf("unexpected config: %+v", cfg) + } +} + +func TestParseTarget(t *testing.T) { + kind, id, ok := parseTarget("chat_id:abc") + if !ok || kind != "chat_id" || id != "abc" { + t.Fatalf("unexpected target parse result: ok=%v kind=%q id=%q", ok, kind, id) + } + kind, id, ok = parseTarget("user_id:zhangsan") + if !ok || kind != "user_id" || id != "zhangsan" { + t.Fatalf("unexpected target parse result: ok=%v kind=%q id=%q", ok, kind, id) + } +} diff --git a/internal/channel/adapters/wecom/crypto.go b/internal/channel/adapters/wecom/crypto.go new file mode 100644 index 00000000..ca60003e --- /dev/null +++ b/internal/channel/adapters/wecom/crypto.go @@ -0,0 +1,52 @@ +package wecom + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" +) + +func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error) { + if len(ciphertext) == 0 { + return nil, fmt.Errorf("ciphertext is empty") + } + key, err := base64.StdEncoding.DecodeString(aesKeyBase64) + if err != nil { + return nil, fmt.Errorf("decode aes key failed: %w", err) + } + if len(key) != 32 { + return nil, fmt.Errorf("invalid aes key length: %d", len(key)) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if len(ciphertext)%aes.BlockSize != 0 { + return nil, fmt.Errorf("invalid ciphertext block size") + } + iv := key[:aes.BlockSize] + out := make([]byte, len(ciphertext)) + cipher.NewCBCDecrypter(block, iv).CryptBlocks(out, ciphertext) + plain, err := pkcs7Unpad(out, 32) + if err != nil { + return nil, err + } + return plain, nil +} + +func pkcs7Unpad(data []byte, maxPad int) ([]byte, error) { + if len(data) == 0 { + return nil, fmt.Errorf("pkcs7 payload is empty") + } + pad := int(data[len(data)-1]) + if pad <= 0 || pad > maxPad || pad > len(data) { + return nil, fmt.Errorf("invalid pkcs7 padding length: %d", pad) + } + padding := bytes.Repeat([]byte{byte(pad)}, pad) + if !bytes.Equal(data[len(data)-pad:], padding) { + return nil, fmt.Errorf("invalid pkcs7 padding bytes") + } + return data[:len(data)-pad], nil +} diff --git a/internal/channel/adapters/wecom/crypto_test.go b/internal/channel/adapters/wecom/crypto_test.go new file mode 100644 index 00000000..00739864 --- /dev/null +++ b/internal/channel/adapters/wecom/crypto_test.go @@ -0,0 +1,50 @@ +package wecom + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "testing" +) + +func TestDecryptFileAES256CBC(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + plain := []byte("hello-wecom-aibot") + ciphertext := encryptPKCS7To32(t, key, plain) + + out, err := DecryptFileAES256CBC(ciphertext, base64.StdEncoding.EncodeToString(key)) + if err != nil { + t.Fatalf("DecryptFileAES256CBC error = %v", err) + } + if !bytes.Equal(out, plain) { + t.Fatalf("plaintext mismatch: got=%q want=%q", string(out), string(plain)) + } +} + +func encryptPKCS7To32(t *testing.T, key []byte, plain []byte) []byte { + t.Helper() + padded := pkcs7PadTo32(plain) + block, err := aes.NewCipher(key) + if err != nil { + t.Fatal(err) + } + iv := key[:aes.BlockSize] + out := make([]byte, len(padded)) + cipher.NewCBCEncrypter(block, iv).CryptBlocks(out, padded) + return out +} + +func pkcs7PadTo32(data []byte) []byte { + const blockSize = 32 + pad := blockSize - (len(data) % blockSize) + if pad == 0 { + pad = blockSize + } + out := make([]byte, len(data)+pad) + copy(out, data) + for i := len(data); i < len(out); i++ { + out[i] = byte(pad) + } + return out +} diff --git a/internal/channel/adapters/wecom/http_client.go b/internal/channel/adapters/wecom/http_client.go new file mode 100644 index 00000000..091048b7 --- /dev/null +++ b/internal/channel/adapters/wecom/http_client.go @@ -0,0 +1,167 @@ +package wecom + +import ( + "context" + "fmt" + "io" + "log/slog" + "mime" + "net/http" + "net/url" + "path" + "strings" + "time" +) + +type HTTPClientOptions struct { + Logger *slog.Logger + Client *http.Client + Transport *http.Transport + Timeout time.Duration + MaxIdleConns int + MaxIdleConnsPerHost int + IdleConnTimeout time.Duration + TLSHandshakeTimeout time.Duration + ResponseHeaderTimeout time.Duration +} + +type HTTPClient struct { + client *http.Client + logger *slog.Logger +} + +type DownloadedFile struct { + Data []byte + FileName string + ContentType string +} + +func NewHTTPClient(opts HTTPClientOptions) *HTTPClient { + if opts.Logger == nil { + opts.Logger = slog.Default() + } + if opts.Timeout <= 0 { + opts.Timeout = 20 * time.Second + } + if opts.IdleConnTimeout <= 0 { + opts.IdleConnTimeout = 90 * time.Second + } + if opts.TLSHandshakeTimeout <= 0 { + opts.TLSHandshakeTimeout = 10 * time.Second + } + if opts.ResponseHeaderTimeout <= 0 { + opts.ResponseHeaderTimeout = 15 * time.Second + } + if opts.MaxIdleConns <= 0 { + opts.MaxIdleConns = 100 + } + if opts.MaxIdleConnsPerHost <= 0 { + opts.MaxIdleConnsPerHost = 10 + } + client := opts.Client + if client == nil { + transport := opts.Transport + if transport == nil { + transport = &http.Transport{ + MaxIdleConns: opts.MaxIdleConns, + MaxIdleConnsPerHost: opts.MaxIdleConnsPerHost, + IdleConnTimeout: opts.IdleConnTimeout, + TLSHandshakeTimeout: opts.TLSHandshakeTimeout, + ResponseHeaderTimeout: opts.ResponseHeaderTimeout, + } + } + client = &http.Client{ + Transport: transport, + Timeout: opts.Timeout, + } + } + return &HTTPClient{ + client: client, + logger: opts.Logger.With(slog.String("component", "wecom_http_client")), + } +} + +func (c *HTTPClient) DownloadFile(ctx context.Context, rawURL string) (DownloadedFile, error) { + u := strings.TrimSpace(rawURL) + if u == "" { + return DownloadedFile{}, fmt.Errorf("download url is required") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return DownloadedFile{}, err + } + resp, err := c.client.Do(req) + if err != nil { + return DownloadedFile{}, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return DownloadedFile{}, fmt.Errorf("download failed with status %d", resp.StatusCode) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return DownloadedFile{}, err + } + fileName := extractFilename(resp.Header.Get("Content-Disposition"), u) + return DownloadedFile{ + Data: data, + FileName: fileName, + ContentType: strings.TrimSpace(resp.Header.Get("Content-Type")), + }, nil +} + +func (c *HTTPClient) DownloadAndDecryptFile(ctx context.Context, rawURL string, aesKey string) (DownloadedFile, error) { + file, err := c.DownloadFile(ctx, rawURL) + if err != nil { + return DownloadedFile{}, err + } + plain, err := DecryptFileAES256CBC(file.Data, aesKey) + if err != nil { + return DownloadedFile{}, err + } + file.Data = plain + return file, nil +} + +func extractFilename(contentDisposition, rawURL string) string { + cd := strings.TrimSpace(contentDisposition) + if cd != "" { + _, params, err := mime.ParseMediaType(cd) + if err == nil { + if v := strings.TrimSpace(params["filename*"]); v != "" { + if decoded := decodeRFC5987Filename(v); decoded != "" { + return decoded + } + return v + } + if v := strings.TrimSpace(params["filename"]); v != "" { + return v + } + } + } + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return "" + } + base := strings.TrimSpace(path.Base(parsed.Path)) + if base == "." || base == "/" { + return "" + } + return base +} + +func decodeRFC5987Filename(value string) string { + parts := strings.SplitN(strings.TrimSpace(value), "''", 2) + encoded := strings.TrimSpace(value) + if len(parts) == 2 { + encoded = strings.TrimSpace(parts[1]) + } + if encoded == "" { + return "" + } + decoded, err := url.QueryUnescape(encoded) + if err != nil { + return "" + } + return strings.TrimSpace(decoded) +} diff --git a/internal/channel/adapters/wecom/http_client_test.go b/internal/channel/adapters/wecom/http_client_test.go new file mode 100644 index 00000000..428d335e --- /dev/null +++ b/internal/channel/adapters/wecom/http_client_test.go @@ -0,0 +1,36 @@ +package wecom + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestDownloadFile_ParsesFilename(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Disposition", "attachment; filename*=UTF-8''hello%20wecom.txt") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer srv.Close() + + client := NewHTTPClient(HTTPClientOptions{}) + file, err := client.DownloadFile(context.Background(), srv.URL+"/download.bin") + if err != nil { + t.Fatalf("DownloadFile error = %v", err) + } + if file.FileName != "hello wecom.txt" { + t.Fatalf("unexpected filename: got=%q", file.FileName) + } + if string(file.Data) != "ok" { + t.Fatalf("unexpected payload: got=%q", string(file.Data)) + } +} + +func TestExtractFilename_FallbackPath(t *testing.T) { + got := extractFilename("", "https://example.com/files/a.png") + if got != "a.png" { + t.Fatalf("unexpected filename fallback: got=%q", got) + } +} diff --git a/internal/channel/adapters/wecom/inbound.go b/internal/channel/adapters/wecom/inbound.go new file mode 100644 index 00000000..43ac2a60 --- /dev/null +++ b/internal/channel/adapters/wecom/inbound.go @@ -0,0 +1,338 @@ +package wecom + +import ( + "context" + "strings" + "time" + + "github.com/memohai/memoh/internal/channel" +) + +func (a *WeComAdapter) handleFrame(ctx context.Context, cfg channel.ChannelConfig, frame WSFrame, handler channel.InboundHandler) error { + switch frame.Cmd { + case WSCmdMsgCallback: + if handler == nil { + return nil + } + var body MessageCallbackBody + if err := frame.DecodeBody(&body); err != nil { + return err + } + msg, ok := buildInboundMessage(body, frame.Headers.ReqID) + if !ok { + return nil + } + a.rememberCallback(body.MsgID, frame.Headers.ReqID, body.ResponseURL, body.ChatID, body.From.UserID) + return handler(ctx, cfg, msg) + case WSCmdEventCallback: + if handler == nil { + return nil + } + var body EventCallbackBody + if err := frame.DecodeBody(&body); err != nil { + return err + } + msg, ok := buildInboundEventMessage(body, frame.Headers.ReqID) + if !ok { + return nil + } + a.rememberCallback(body.MsgID, frame.Headers.ReqID, body.ResponseURL, body.ChatID, body.From.UserID) + return handler(ctx, cfg, msg) + default: + return nil + } +} + +func buildInboundMessage(body MessageCallbackBody, reqID string) (channel.InboundMessage, bool) { + text, attachments := extractBodyContent(body) + if strings.TrimSpace(text) == "" && len(attachments) == 0 { + return channel.InboundMessage{}, false + } + target := resolveDeliveryTarget(body.ChatID, body.From.UserID) + if target == "" { + return channel.InboundMessage{}, false + } + convType := normalizeConversationType(body.ChatType) + convID := strings.TrimSpace(body.ChatID) + if convID == "" { + convID = strings.TrimSpace(body.From.UserID) + } + msg := channel.InboundMessage{ + Channel: Type, + Message: channel.Message{ + ID: strings.TrimSpace(body.MsgID), + Format: channel.MessageFormatPlain, + Text: strings.TrimSpace(text), + Attachments: attachments, + Metadata: map[string]any{ + "response_url": strings.TrimSpace(body.ResponseURL), + }, + }, + ReplyTarget: target, + Sender: channel.Identity{ + SubjectID: strings.TrimSpace(body.From.UserID), + Attributes: map[string]string{ + "user_id": strings.TrimSpace(body.From.UserID), + "chat_id": strings.TrimSpace(body.ChatID), + }, + }, + Conversation: channel.Conversation{ + ID: convID, + Type: convType, + }, + ReceivedAt: parseCreateTime(body.CreateTime), + Source: "wecom", + Metadata: map[string]any{ + "req_id": strings.TrimSpace(reqID), + "chat_id": strings.TrimSpace(body.ChatID), + "chat_type": strings.TrimSpace(body.ChatType), + "response_url": strings.TrimSpace(body.ResponseURL), + }, + } + return msg, true +} + +func buildInboundEventMessage(body EventCallbackBody, reqID string) (channel.InboundMessage, bool) { + eventType := normalizeEventType(body.Event.EventType, body.Event.EventType2) + if eventType == "" { + return channel.InboundMessage{}, false + } + target := resolveDeliveryTarget(body.ChatID, body.From.UserID) + if target == "" { + return channel.InboundMessage{}, false + } + convType := normalizeConversationType(body.ChatType) + convID := strings.TrimSpace(body.ChatID) + if convID == "" { + convID = strings.TrimSpace(body.From.UserID) + } + return channel.InboundMessage{ + Channel: Type, + Message: channel.Message{ + ID: strings.TrimSpace(body.MsgID), + Format: channel.MessageFormatPlain, + Text: eventType, + Metadata: map[string]any{ + "event_type": eventType, + "task_id": strings.TrimSpace(body.Event.TaskID), + "event_key": strings.TrimSpace(body.Event.EventKey), + "event_code": strings.TrimSpace(body.Event.Code), + "event_reason": strings.TrimSpace(body.Event.Reason), + "task_status": strings.TrimSpace(body.Task.TaskStatus), + "response_url": strings.TrimSpace(body.ResponseURL), + }, + }, + ReplyTarget: target, + Sender: channel.Identity{ + SubjectID: strings.TrimSpace(body.From.UserID), + Attributes: map[string]string{ + "user_id": strings.TrimSpace(body.From.UserID), + "chat_id": strings.TrimSpace(body.ChatID), + }, + }, + Conversation: channel.Conversation{ + ID: convID, + Type: convType, + }, + ReceivedAt: parseCreateTime(body.CreateTime), + Source: "wecom", + Metadata: map[string]any{ + "is_event": true, + "event_type": eventType, + "event_key": strings.TrimSpace(body.Event.EventKey), + "event_code": strings.TrimSpace(body.Event.Code), + "event_reason": strings.TrimSpace(body.Event.Reason), + "req_id": strings.TrimSpace(reqID), + "response_url": strings.TrimSpace(body.ResponseURL), + }, + }, true +} + +func extractBodyContent(body MessageCallbackBody) (string, []channel.Attachment) { + switch strings.ToLower(strings.TrimSpace(body.MsgType)) { + case "text": + if body.Text == nil { + return "", nil + } + return strings.TrimSpace(body.Text.Content), nil + case "markdown": + if body.Markdown == nil { + return "", nil + } + return strings.TrimSpace(body.Markdown.Content), nil + case "voice": + if body.Voice == nil { + return "", nil + } + return strings.TrimSpace(body.Voice.Content), nil + case "video": + if body.Video == nil { + return "", nil + } + att := channel.Attachment{ + Type: channel.AttachmentVideo, + URL: strings.TrimSpace(body.Video.URL), + SourcePlatform: Type.String(), + Metadata: map[string]any{ + "aeskey": strings.TrimSpace(body.Video.AESKey), + }, + } + return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)} + case "image": + if body.Image == nil { + return "", nil + } + att := channel.Attachment{ + Type: channel.AttachmentImage, + URL: strings.TrimSpace(body.Image.URL), + SourcePlatform: Type.String(), + Metadata: map[string]any{ + "aeskey": strings.TrimSpace(body.Image.AESKey), + }, + } + return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)} + case "file": + if body.File == nil { + return "", nil + } + att := channel.Attachment{ + Type: channel.AttachmentFile, + URL: strings.TrimSpace(body.File.URL), + SourcePlatform: Type.String(), + Name: strings.TrimSpace(body.File.FileName), + Metadata: map[string]any{ + "aeskey": strings.TrimSpace(body.File.AESKey), + }, + } + return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)} + case "mixed": + var textParts []string + attachments := make([]channel.Attachment, 0) + for _, item := range body.Mixed { + switch strings.ToLower(strings.TrimSpace(item.MsgType)) { + case "text": + if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" { + textParts = append(textParts, strings.TrimSpace(item.Text.Content)) + } + case "markdown": + if item.Markdown != nil && strings.TrimSpace(item.Markdown.Content) != "" { + textParts = append(textParts, strings.TrimSpace(item.Markdown.Content)) + } + case "image": + if item.Image == nil { + continue + } + att := channel.Attachment{ + Type: channel.AttachmentImage, + URL: strings.TrimSpace(item.Image.URL), + SourcePlatform: Type.String(), + Metadata: map[string]any{ + "aeskey": strings.TrimSpace(item.Image.AESKey), + }, + } + attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att)) + case "file": + if item.File == nil { + continue + } + att := channel.Attachment{ + Type: channel.AttachmentFile, + URL: strings.TrimSpace(item.File.URL), + SourcePlatform: Type.String(), + Name: strings.TrimSpace(item.File.FileName), + Metadata: map[string]any{ + "aeskey": strings.TrimSpace(item.File.AESKey), + }, + } + attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att)) + case "video": + if item.Video == nil { + continue + } + att := channel.Attachment{ + Type: channel.AttachmentVideo, + URL: strings.TrimSpace(item.Video.URL), + SourcePlatform: Type.String(), + Metadata: map[string]any{ + "aeskey": strings.TrimSpace(item.Video.AESKey), + }, + } + attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att)) + case "voice": + if item.Voice != nil && strings.TrimSpace(item.Voice.Content) != "" { + textParts = append(textParts, strings.TrimSpace(item.Voice.Content)) + } + } + } + return strings.Join(textParts, "\n"), attachments + default: + return "", nil + } +} + +func resolveDeliveryTarget(chatID, userID string) string { + if v := strings.TrimSpace(chatID); v != "" { + return "chat_id:" + v + } + if v := strings.TrimSpace(userID); v != "" { + return "user_id:" + v + } + return "" +} + +func normalizeConversationType(chatType string) string { + if strings.EqualFold(strings.TrimSpace(chatType), "group") { + return "group" + } + return "private" +} + +func parseCreateTime(ts int64) time.Time { + if t := unixMilliseconds(ts); !t.IsZero() { + return t + } + return time.Now().UTC() +} + +func (a *WeComAdapter) rememberCallback(msgID, reqID, responseURL, chatID, userID string) { + if a == nil || a.cache == nil { + return + } + msgID = strings.TrimSpace(msgID) + if msgID == "" { + return + } + if strings.TrimSpace(reqID) == "" { + return + } + a.cache.Put(msgID, callbackContext{ + ReqID: strings.TrimSpace(reqID), + ResponseURL: strings.TrimSpace(responseURL), + ChatID: strings.TrimSpace(chatID), + UserID: strings.TrimSpace(userID), + CreatedAt: time.Now().UTC(), + }) +} + +func normalizeEventType(values ...string) string { + candidates := make([]string, 0, len(values)) + for _, v := range values { + v = strings.TrimSpace(strings.ToLower(v)) + if v == "" { + continue + } + v = strings.ReplaceAll(v, "-", "_") + v = strings.ReplaceAll(v, " ", "_") + candidates = append(candidates, v) + } + for _, v := range candidates { + switch v { + case "enter_chat", "template_card_event", "feedback_event": + return v + } + } + if len(candidates) == 0 { + return "" + } + return candidates[0] +} diff --git a/internal/channel/adapters/wecom/inbound_test.go b/internal/channel/adapters/wecom/inbound_test.go new file mode 100644 index 00000000..fd2ab75a --- /dev/null +++ b/internal/channel/adapters/wecom/inbound_test.go @@ -0,0 +1,103 @@ +package wecom + +import "testing" + +func TestBuildInboundMessage_Text(t *testing.T) { + msg, ok := buildInboundMessage(MessageCallbackBody{ + MsgID: "m1", + ChatID: "chat-1", + ChatType: "group", + From: CallbackFrom{ + UserID: "u1", + }, + MsgType: "text", + Text: &MessageText{ + Content: "hello", + }, + }, "req-1") + if !ok { + t.Fatal("expected inbound message") + } + if msg.Message.Text != "hello" { + t.Fatalf("unexpected text: %q", msg.Message.Text) + } + if msg.ReplyTarget != "chat_id:chat-1" { + t.Fatalf("unexpected target: %q", msg.ReplyTarget) + } + if msg.Metadata["req_id"] != "req-1" { + t.Fatalf("unexpected req_id: %v", msg.Metadata["req_id"]) + } +} + +func TestBuildInboundMessage_Markdown(t *testing.T) { + msg, ok := buildInboundMessage(MessageCallbackBody{ + MsgID: "m2", + ChatID: "chat-2", + ChatType: "private", + From: CallbackFrom{UserID: "u2"}, + MsgType: "markdown", + Markdown: &MessageText{ + Content: "**hello**", + }, + }, "req-2") + if !ok { + t.Fatal("expected inbound markdown message") + } + if msg.Message.Text != "**hello**" { + t.Fatalf("unexpected markdown text: %q", msg.Message.Text) + } +} + +func TestBuildInboundMessage_Video(t *testing.T) { + msg, ok := buildInboundMessage(MessageCallbackBody{ + MsgID: "m3", + ChatID: "chat-3", + ChatType: "group", + From: CallbackFrom{UserID: "u3"}, + MsgType: "video", + Video: &MessageVideo{ + URL: "https://example.com/v.mp4", + AESKey: "k", + }, + }, "req-3") + if !ok { + t.Fatal("expected inbound video message") + } + if len(msg.Message.Attachments) != 1 { + t.Fatalf("expected one attachment, got %d", len(msg.Message.Attachments)) + } + if msg.Message.Attachments[0].Type != "video" { + t.Fatalf("unexpected attachment type: %s", msg.Message.Attachments[0].Type) + } +} + +func TestBuildInboundEventMessage_NormalizeEventType(t *testing.T) { + msg, ok := buildInboundEventMessage(EventCallbackBody{ + MsgID: "e1", + ChatID: "chat-1", + ChatType: "group", + From: CallbackFrom{UserID: "u1"}, + CreateTime: 1, + MsgType: "event", + Event: EventPayload{ + EventType2: "template-card-event", + EventKey: "btn_1", + TaskID: "task_1", + }, + Task: EventTask{ + TaskStatus: "done", + }, + }, "req-e1") + if !ok { + t.Fatal("expected event message") + } + if msg.Message.Text != "template_card_event" { + t.Fatalf("unexpected normalized event text: %q", msg.Message.Text) + } + if msg.Message.Metadata["event_key"] != "btn_1" { + t.Fatalf("unexpected event key: %v", msg.Message.Metadata["event_key"]) + } + if msg.Message.Metadata["task_status"] != "done" { + t.Fatalf("unexpected task status: %v", msg.Message.Metadata["task_status"]) + } +} diff --git a/internal/channel/adapters/wecom/outbound.go b/internal/channel/adapters/wecom/outbound.go new file mode 100644 index 00000000..6ac1d7c4 --- /dev/null +++ b/internal/channel/adapters/wecom/outbound.go @@ -0,0 +1,309 @@ +package wecom + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/base64" + "fmt" + "io" + "strings" + + "github.com/memohai/memoh/internal/channel" +) + +func (a *WeComAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) { + _ = cfg + if a.http == nil { + return channel.AttachmentPayload{}, fmt.Errorf("wecom http client not configured") + } + url := strings.TrimSpace(attachment.URL) + if url == "" { + return channel.AttachmentPayload{}, fmt.Errorf("wecom attachment url is required") + } + aesKey := "" + if attachment.Metadata != nil { + if value, ok := attachment.Metadata["aeskey"].(string); ok { + aesKey = strings.TrimSpace(value) + } + } + var file DownloadedFile + var err error + if aesKey != "" { + file, err = a.http.DownloadAndDecryptFile(ctx, url, aesKey) + } else { + file, err = a.http.DownloadFile(ctx, url) + } + if err != nil { + return channel.AttachmentPayload{}, err + } + return channel.AttachmentPayload{ + Reader: io.NopCloser(bytes.NewReader(file.Data)), + Mime: strings.TrimSpace(file.ContentType), + Name: strings.TrimSpace(file.FileName), + Size: int64(len(file.Data)), + }, nil +} + +type markdownPayload struct { + Content string `json:"content"` +} + +func buildSendPayload(msg channel.Message, targetID string) (any, string, string, error) { + if strings.TrimSpace(targetID) == "" { + return nil, "", "", fmt.Errorf("wecom target id is required") + } + reqID := NewReqID(WSCmdSendMessage) + if card, ok := readTemplateCard(msg.Metadata); ok { + return SendMessageTemplateCardBody{ + ChatID: targetID, + MsgType: "template_card", + TemplateCard: card, + }, WSCmdSendMessage, reqID, nil + } + + // aibot_send_msg currently supports markdown/template_card in official SDK. + // Attachments should be sent through callback-reply path (aibot_respond_msg). + if len(msg.Attachments) > 0 { + return nil, "", "", fmt.Errorf("wecom proactive send does not support attachments; use reply flow") + } + + text := strings.TrimSpace(msg.PlainText()) + if text == "" { + return nil, "", "", fmt.Errorf("wecom outbound text is required") + } + return SendMessageMarkdownBody{ + ChatID: targetID, + MsgType: "markdown", + Markdown: markdownPayload{ + Content: text, + }, + }, WSCmdSendMessage, reqID, nil +} + +func buildRespondPayload(msg channel.Message, replyReqID string) (any, string, string, error) { + return buildRespondPayloadWithStream(msg, replyReqID, "", true) +} + +func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, streamID string, finish bool) (any, string, string, error) { + reqID := strings.TrimSpace(replyReqID) + if reqID == "" { + return nil, "", "", fmt.Errorf("reply req_id is required") + } + if finish { + if body, ok := buildWelcomePayload(msg); ok { + return body, WSCmdRespondWelcome, reqID, nil + } + if body, ok := readUpdateTemplateCard(msg.Metadata); ok { + return body, WSCmdRespondUpdate, reqID, nil + } + } + text := strings.TrimSpace(msg.PlainText()) + if finish && text == "" && len(msg.Attachments) == 0 { + return nil, "", "", fmt.Errorf("wecom reply payload is empty") + } + if !finish && text == "" { + return nil, "", "", fmt.Errorf("wecom stream delta content is empty") + } + streamID = strings.TrimSpace(streamID) + if streamID == "" { + streamID = NewReqID("stream") + } + stream := StreamReplyBlock{ + ID: streamID, + Finish: finish, + Content: text, + } + if feedbackID := readFeedbackID(msg.Metadata); feedbackID != "" { + stream.Feedback = &StreamReplyFeedback{ID: feedbackID} + } + if finish && len(msg.Attachments) > 0 { + first := msg.Attachments[0] + base64Data := extractBase64Content(first.Base64) + if base64Data != "" { + raw, err := base64.StdEncoding.DecodeString(base64Data) + if err == nil && len(raw) > 0 { + stream.MsgItems = []StreamReplyItem{ + { + MsgType: "image", + Image: &StreamReplyImage{ + Base64: base64.StdEncoding.EncodeToString(raw), + MD5: fmt.Sprintf("%x", md5.Sum(raw)), + }, + }, + } + } + } + } + if card, ok := readTemplateCard(msg.Metadata); ok { + return StreamWithTemplateCardReplyBody{ + MsgType: "stream_with_template_card", + Stream: stream, + TemplateCard: card, + }, WSCmdRespond, reqID, nil + } + return StreamReplyBody{ + MsgType: "stream", + Stream: stream, + }, WSCmdRespond, reqID, nil +} + +func buildWelcomePayload(msg channel.Message) (any, bool) { + if !readBool(msg.Metadata, "wecom_welcome") { + return nil, false + } + if card, ok := readTemplateCard(msg.Metadata); ok { + return WelcomeTemplateCardReplyBody{ + MsgType: "template_card", + TemplateCard: card, + }, true + } + text := strings.TrimSpace(msg.PlainText()) + if text == "" { + return nil, false + } + return WelcomeTextReplyBody{ + MsgType: "text", + Text: welcomeTextBody{ + Content: text, + }, + }, true +} + +func readTemplateCard(metadata map[string]any) (map[string]any, bool) { + if metadata == nil { + return nil, false + } + raw, ok := metadata["wecom_template_card"] + if !ok { + return nil, false + } + card, ok := raw.(map[string]any) + if !ok || len(card) == 0 { + return nil, false + } + return card, true +} + +func readUpdateTemplateCard(metadata map[string]any) (UpdateTemplateCardBody, bool) { + if metadata == nil { + return UpdateTemplateCardBody{}, false + } + raw, ok := metadata["wecom_update_template_card"] + if !ok { + return UpdateTemplateCardBody{}, false + } + card, ok := raw.(map[string]any) + if !ok || len(card) == 0 { + return UpdateTemplateCardBody{}, false + } + body := UpdateTemplateCardBody{ + ResponseType: "update_template_card", + TemplateCard: card, + } + if userIDs := readStringSlice(metadata["wecom_update_userids"]); len(userIDs) > 0 { + body.UserIDs = userIDs + } + return body, true +} + +func readStringSlice(raw any) []string { + switch v := raw.(type) { + case []string: + out := make([]string, 0, len(v)) + for _, item := range v { + if s := strings.TrimSpace(item); s != "" { + out = append(out, s) + } + } + return out + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok && strings.TrimSpace(s) != "" { + out = append(out, strings.TrimSpace(s)) + } + } + return out + case string: + if strings.TrimSpace(v) == "" { + return nil + } + parts := strings.Split(v, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if s := strings.TrimSpace(part); s != "" { + out = append(out, s) + } + } + return out + default: + return nil + } +} + +func readFeedbackID(metadata map[string]any) string { + if metadata == nil { + return "" + } + raw, ok := metadata["wecom_feedback_id"] + if !ok { + return "" + } + if v, ok := raw.(string); ok { + return strings.TrimSpace(v) + } + return "" +} + +func readBool(metadata map[string]any, key string) bool { + if metadata == nil || strings.TrimSpace(key) == "" { + return false + } + raw, ok := metadata[key] + if !ok { + return false + } + v, ok := raw.(bool) + return ok && v +} + +func extractBase64Content(v string) string { + value := strings.TrimSpace(v) + if value == "" { + return "" + } + if idx := strings.Index(value, ","); idx > 0 && strings.Contains(strings.ToLower(value[:idx]), "base64") { + return strings.TrimSpace(value[idx+1:]) + } + return value +} + +func (a *WeComAdapter) getClient(botID string) *WSClient { + key := strings.TrimSpace(botID) + if key == "" { + return nil + } + a.mu.RLock() + defer a.mu.RUnlock() + return a.clients[key] +} + +func (a *WeComAdapter) lookupCallbackContext(reply *channel.ReplyRef) (callbackContext, bool) { + if a == nil || a.cache == nil || reply == nil { + return callbackContext{}, false + } + messageID := strings.TrimSpace(reply.MessageID) + if messageID == "" { + return callbackContext{}, false + } + return a.cache.Get(messageID) +} + +func (a *WeComAdapter) ensureHTTPClient() { + a.mu.Lock() + defer a.mu.Unlock() + if a.http == nil { + a.http = NewHTTPClient(HTTPClientOptions{Logger: a.logger}) + } +} diff --git a/internal/channel/adapters/wecom/outbound_test.go b/internal/channel/adapters/wecom/outbound_test.go new file mode 100644 index 00000000..82a80772 --- /dev/null +++ b/internal/channel/adapters/wecom/outbound_test.go @@ -0,0 +1,239 @@ +package wecom + +import ( + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +func TestBuildSendPayload_Text(t *testing.T) { + payload, cmd, reqID, err := buildSendPayload(channel.Message{ + Format: channel.MessageFormatPlain, + Text: "hello", + }, "chat_1") + if err != nil { + t.Fatalf("buildSendPayload error = %v", err) + } + if cmd != WSCmdSendMessage { + t.Fatalf("unexpected cmd: %q", cmd) + } + if reqID == "" { + t.Fatal("reqID should not be empty") + } + p, ok := payload.(SendMessageMarkdownBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.Markdown.Content != "hello" { + t.Fatalf("unexpected payload content: %q", p.Markdown.Content) + } +} + +func TestBuildSendPayload_AttachmentNotSupported(t *testing.T) { + _, _, _, err := buildSendPayload(channel.Message{ + Attachments: []channel.Attachment{ + {Type: channel.AttachmentImage, Base64: "aGVsbG8="}, + }, + }, "chat_1") + if err == nil { + t.Fatal("expected error for proactive attachment send") + } +} + +func TestBuildRespondPayload_Stream(t *testing.T) { + payload, cmd, reqID, err := buildRespondPayload(channel.Message{ + Format: channel.MessageFormatMarkdown, + Text: "world", + }, "req_abc") + if err != nil { + t.Fatalf("buildRespondPayload error = %v", err) + } + if cmd != WSCmdRespond { + t.Fatalf("unexpected cmd: %q", cmd) + } + if reqID != "req_abc" { + t.Fatalf("unexpected req id: %q", reqID) + } + p, ok := payload.(StreamReplyBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.MsgType != "stream" || p.Stream.Content != "world" || !p.Stream.Finish { + t.Fatalf("unexpected stream payload: %+v", p) + } +} + +func TestBuildRespondPayload_StreamWithFeedback(t *testing.T) { + payload, cmd, _, err := buildRespondPayload(channel.Message{ + Text: "world", + Metadata: map[string]any{ + "wecom_feedback_id": "feedback_1", + }, + }, "req_abc") + if err != nil { + t.Fatalf("buildRespondPayload error = %v", err) + } + if cmd != WSCmdRespond { + t.Fatalf("unexpected cmd: %q", cmd) + } + p, ok := payload.(StreamReplyBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.Stream.Feedback == nil || p.Stream.Feedback.ID != "feedback_1" { + t.Fatalf("unexpected feedback: %+v", p.Stream.Feedback) + } +} + +func TestBuildRespondPayloadWithStream_Delta(t *testing.T) { + payload, cmd, reqID, err := buildRespondPayloadWithStream(channel.Message{ + Text: "delta", + }, "req_abc", "stream_1", false) + if err != nil { + t.Fatalf("buildRespondPayloadWithStream error = %v", err) + } + if cmd != WSCmdRespond || reqID != "req_abc" { + t.Fatalf("unexpected cmd/reqid: %q %q", cmd, reqID) + } + p, ok := payload.(StreamReplyBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.Stream.ID != "stream_1" || p.Stream.Finish { + t.Fatalf("unexpected stream payload: %+v", p.Stream) + } +} + +func TestBuildRespondPayloadWithStream_EmptyDeltaRejected(t *testing.T) { + _, _, _, err := buildRespondPayloadWithStream(channel.Message{}, "req_abc", "stream_1", false) + if err == nil { + t.Fatal("expected error for empty delta content") + } +} + +func TestBuildSendPayload_TemplateCard(t *testing.T) { + payload, cmd, _, err := buildSendPayload(channel.Message{ + Metadata: map[string]any{ + "wecom_template_card": map[string]any{ + "card_type": "text_notice", + }, + }, + }, "chat_1") + if err != nil { + t.Fatalf("buildSendPayload error = %v", err) + } + if cmd != WSCmdSendMessage { + t.Fatalf("unexpected cmd: %q", cmd) + } + p, ok := payload.(SendMessageTemplateCardBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.MsgType != "template_card" { + t.Fatalf("unexpected msg type: %q", p.MsgType) + } +} + +func TestBuildRespondPayload_StreamWithTemplateCard(t *testing.T) { + payload, cmd, _, err := buildRespondPayload(channel.Message{ + Text: "x", + Metadata: map[string]any{ + "wecom_template_card": map[string]any{ + "card_type": "text_notice", + }, + }, + }, "req_abc") + if err != nil { + t.Fatalf("buildRespondPayload error = %v", err) + } + if cmd != WSCmdRespond { + t.Fatalf("unexpected cmd: %q", cmd) + } + p, ok := payload.(StreamWithTemplateCardReplyBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.MsgType != "stream_with_template_card" { + t.Fatalf("unexpected msg type: %q", p.MsgType) + } +} + +func TestBuildRespondPayload_UpdateTemplateCard(t *testing.T) { + payload, cmd, reqID, err := buildRespondPayload(channel.Message{ + Metadata: map[string]any{ + "wecom_update_template_card": map[string]any{ + "card_type": "text_notice", + "task_id": "task-1", + }, + "wecom_update_userids": []any{"u1", "u2"}, + }, + }, "req_abc") + if err != nil { + t.Fatalf("buildRespondPayload error = %v", err) + } + if cmd != WSCmdRespondUpdate { + t.Fatalf("unexpected cmd: %q", cmd) + } + if reqID != "req_abc" { + t.Fatalf("unexpected req id: %q", reqID) + } + p, ok := payload.(UpdateTemplateCardBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.ResponseType != "update_template_card" { + t.Fatalf("unexpected response type: %q", p.ResponseType) + } + if len(p.UserIDs) != 2 || p.UserIDs[0] != "u1" || p.UserIDs[1] != "u2" { + t.Fatalf("unexpected userids: %#v", p.UserIDs) + } +} + +func TestBuildRespondPayload_WelcomeText(t *testing.T) { + payload, cmd, reqID, err := buildRespondPayload(channel.Message{ + Text: "welcome", + Metadata: map[string]any{ + "wecom_welcome": true, + }, + }, "req_abc") + if err != nil { + t.Fatalf("buildRespondPayload error = %v", err) + } + if cmd != WSCmdRespondWelcome { + t.Fatalf("unexpected cmd: %q", cmd) + } + if reqID != "req_abc" { + t.Fatalf("unexpected req id: %q", reqID) + } + p, ok := payload.(WelcomeTextReplyBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.MsgType != "text" || p.Text.Content != "welcome" { + t.Fatalf("unexpected welcome payload: %+v", p) + } +} + +func TestBuildRespondPayload_WelcomeTemplateCard(t *testing.T) { + payload, cmd, _, err := buildRespondPayload(channel.Message{ + Metadata: map[string]any{ + "wecom_welcome": true, + "wecom_template_card": map[string]any{ + "card_type": "text_notice", + }, + }, + }, "req_abc") + if err != nil { + t.Fatalf("buildRespondPayload error = %v", err) + } + if cmd != WSCmdRespondWelcome { + t.Fatalf("unexpected cmd: %q", cmd) + } + p, ok := payload.(WelcomeTemplateCardReplyBody) + if !ok { + t.Fatalf("unexpected payload type: %T", payload) + } + if p.MsgType != "template_card" { + t.Fatalf("unexpected msg type: %q", p.MsgType) + } +} diff --git a/internal/channel/adapters/wecom/protocol.go b/internal/channel/adapters/wecom/protocol.go new file mode 100644 index 00000000..69e2ca9b --- /dev/null +++ b/internal/channel/adapters/wecom/protocol.go @@ -0,0 +1,256 @@ +package wecom + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" +) + +const ( + defaultWSURL = "wss://openws.work.weixin.qq.com" +) + +const ( + WSCmdSubscribe = "aibot_subscribe" + WSCmdHeartbeat = "ping" + WSCmdRespond = "aibot_respond_msg" + WSCmdRespondWelcome = "aibot_respond_welcome_msg" + WSCmdRespondUpdate = "aibot_respond_update_msg" + WSCmdSendMessage = "aibot_send_msg" + WSCmdMsgCallback = "aibot_msg_callback" + WSCmdEventCallback = "aibot_event_callback" +) + +type WSHeaders struct { + ReqID string `json:"req_id"` +} + +type WSFrame struct { + Cmd string `json:"cmd,omitempty"` + Headers WSHeaders `json:"headers"` + Body json.RawMessage `json:"body,omitempty"` + ErrCode int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` +} + +func (f WSFrame) DecodeBody(dst any) error { + if len(f.Body) == 0 { + return fmt.Errorf("wecom frame body is empty") + } + if dst == nil { + return fmt.Errorf("decode target is nil") + } + return json.Unmarshal(f.Body, dst) +} + +func BuildFrame(cmd, reqID string, body any) (WSFrame, error) { + frame := WSFrame{ + Cmd: cmd, + Headers: WSHeaders{ + ReqID: strings.TrimSpace(reqID), + }, + } + if frame.Headers.ReqID == "" { + return WSFrame{}, fmt.Errorf("req_id is required") + } + if body == nil { + return frame, nil + } + raw, err := json.Marshal(body) + if err != nil { + return WSFrame{}, err + } + frame.Body = raw + return frame, nil +} + +func NewReqID(prefix string) string { + p := strings.TrimSpace(prefix) + if p == "" { + return uuid.NewString() + } + return p + "_" + uuid.NewString() +} + +type AuthCredentials struct { + BotID string + Secret string +} + +func (c AuthCredentials) Validate() error { + if strings.TrimSpace(c.BotID) == "" { + return fmt.Errorf("wecom bot_id is required") + } + if strings.TrimSpace(c.Secret) == "" { + return fmt.Errorf("wecom secret is required") + } + return nil +} + +type SubscribeBody struct { + BotID string `json:"bot_id"` + Secret string `json:"secret"` +} + +type CallbackFrom struct { + UserID string `json:"userid,omitempty"` +} + +type MessageText struct { + Content string `json:"content,omitempty"` +} + +type MessageImage struct { + URL string `json:"url,omitempty"` + AESKey string `json:"aeskey,omitempty"` +} + +type MessageFile struct { + URL string `json:"url,omitempty"` + AESKey string `json:"aeskey,omitempty"` + FileName string `json:"file_name,omitempty"` +} + +type MessageVoice struct { + Content string `json:"content,omitempty"` +} + +type MessageVideo struct { + URL string `json:"url,omitempty"` + AESKey string `json:"aeskey,omitempty"` +} + +type MessageMixedItem struct { + MsgType string `json:"msgtype,omitempty"` + Text *MessageText `json:"text,omitempty"` + Markdown *MessageText `json:"markdown,omitempty"` + Image *MessageImage `json:"image,omitempty"` + File *MessageFile `json:"file,omitempty"` + Voice *MessageVoice `json:"voice,omitempty"` + Video *MessageVideo `json:"video,omitempty"` +} + +type MessageQuote struct { + MsgID string `json:"msgid,omitempty"` +} + +type MessageCallbackBody struct { + MsgID string `json:"msgid,omitempty"` + AIBotID string `json:"aibotid,omitempty"` + ChatID string `json:"chatid,omitempty"` + ChatType string `json:"chattype,omitempty"` + From CallbackFrom `json:"from,omitempty"` + CreateTime int64 `json:"create_time,omitempty"` + MsgType string `json:"msgtype,omitempty"` + ResponseURL string `json:"response_url,omitempty"` + Text *MessageText `json:"text,omitempty"` + Markdown *MessageText `json:"markdown,omitempty"` + Image *MessageImage `json:"image,omitempty"` + File *MessageFile `json:"file,omitempty"` + Voice *MessageVoice `json:"voice,omitempty"` + Video *MessageVideo `json:"video,omitempty"` + Mixed []MessageMixedItem `json:"mixed,omitempty"` + Quote *MessageQuote `json:"quote,omitempty"` +} + +type EventPayload struct { + EventType string `json:"event_type,omitempty"` + EventType2 string `json:"eventtype,omitempty"` + EventKey string `json:"event_key,omitempty"` + TaskID string `json:"task_id,omitempty"` + Code string `json:"code,omitempty"` + Reason string `json:"reason,omitempty"` +} + +type EventTask struct { + TaskID string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` +} + +type EventCallbackBody struct { + MsgID string `json:"msgid,omitempty"` + AIBotID string `json:"aibotid,omitempty"` + ChatID string `json:"chatid,omitempty"` + ChatType string `json:"chattype,omitempty"` + From CallbackFrom `json:"from,omitempty"` + CreateTime int64 `json:"create_time,omitempty"` + MsgType string `json:"msgtype,omitempty"` + ResponseURL string `json:"response_url,omitempty"` + Event EventPayload `json:"event,omitempty"` + Task EventTask `json:"task,omitempty"` +} + +type StreamReplyBody struct { + MsgType string `json:"msgtype"` + Stream StreamReplyBlock `json:"stream"` +} + +type StreamReplyBlock struct { + ID string `json:"id"` + Finish bool `json:"finish,omitempty"` + Content string `json:"content,omitempty"` + MsgItems []StreamReplyItem `json:"msg_item,omitempty"` + Feedback *StreamReplyFeedback `json:"feedback,omitempty"` +} + +type StreamReplyItem struct { + MsgType string `json:"msgtype"` + Image *StreamReplyImage `json:"image,omitempty"` +} + +type StreamReplyImage struct { + Base64 string `json:"base64"` + MD5 string `json:"md5"` +} + +type StreamReplyFeedback struct { + ID string `json:"id"` +} + +type SendMessageMarkdownBody struct { + ChatID string `json:"chatid"` + MsgType string `json:"msgtype"` + Markdown markdownPayload `json:"markdown"` +} + +type SendMessageTemplateCardBody struct { + ChatID string `json:"chatid"` + MsgType string `json:"msgtype"` + TemplateCard map[string]any `json:"template_card"` +} + +type StreamWithTemplateCardReplyBody struct { + MsgType string `json:"msgtype"` + Stream StreamReplyBlock `json:"stream"` + TemplateCard map[string]any `json:"template_card"` +} + +type WelcomeTextReplyBody struct { + MsgType string `json:"msgtype"` + Text welcomeTextBody `json:"text"` +} + +type welcomeTextBody struct { + Content string `json:"content"` +} + +type WelcomeTemplateCardReplyBody struct { + MsgType string `json:"msgtype"` + TemplateCard map[string]any `json:"template_card"` +} + +type UpdateTemplateCardBody struct { + ResponseType string `json:"response_type"` + UserIDs []string `json:"userids,omitempty"` + TemplateCard map[string]any `json:"template_card"` +} + +func unixMilliseconds(ts int64) time.Time { + if ts <= 0 { + return time.Time{} + } + return time.UnixMilli(ts) +} diff --git a/internal/channel/adapters/wecom/protocol_test.go b/internal/channel/adapters/wecom/protocol_test.go new file mode 100644 index 00000000..a002e462 --- /dev/null +++ b/internal/channel/adapters/wecom/protocol_test.go @@ -0,0 +1,29 @@ +package wecom + +import "testing" + +func TestBuildFrameAndDecodeBody(t *testing.T) { + type body struct { + BotID string `json:"bot_id"` + } + frame, err := BuildFrame(WSCmdSubscribe, "req-1", body{BotID: "bot123"}) + if err != nil { + t.Fatalf("BuildFrame error = %v", err) + } + var decoded body + if err := frame.DecodeBody(&decoded); err != nil { + t.Fatalf("DecodeBody error = %v", err) + } + if decoded.BotID != "bot123" { + t.Fatalf("unexpected decoded body: %+v", decoded) + } +} + +func TestAuthCredentialsValidate(t *testing.T) { + if err := (AuthCredentials{}).Validate(); err == nil { + t.Fatal("expected validation error for empty credentials") + } + if err := (AuthCredentials{BotID: "id", Secret: "sec"}).Validate(); err != nil { + t.Fatalf("unexpected validation error: %v", err) + } +} diff --git a/internal/channel/adapters/wecom/wecom.go b/internal/channel/adapters/wecom/wecom.go new file mode 100644 index 00000000..81df25a7 --- /dev/null +++ b/internal/channel/adapters/wecom/wecom.go @@ -0,0 +1,434 @@ +package wecom + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/memohai/memoh/internal/channel" +) + +const Type channel.ChannelType = "wecom" + +type wsClientFactory func(opts WSClientOptions) *WSClient + +type WeComAdapter struct { + logger *slog.Logger + + mu sync.RWMutex + clients map[string]*WSClient + http *HTTPClient + cache *callbackContextCache + + newWSClient wsClientFactory +} + +func NewWeComAdapter(log *slog.Logger) *WeComAdapter { + if log == nil { + log = slog.Default() + } + return &WeComAdapter{ + logger: log.With(slog.String("adapter", "wecom")), + clients: make(map[string]*WSClient), + http: NewHTTPClient(HTTPClientOptions{Logger: log}), + cache: newCallbackContextCache(24 * time.Hour), + newWSClient: func(opts WSClientOptions) *WSClient { return NewWSClient(opts) }, + } +} + +func (a *WeComAdapter) Type() channel.ChannelType { return Type } + +func (a *WeComAdapter) Descriptor() channel.Descriptor { + return channel.Descriptor{ + Type: Type, + DisplayName: "WeCom", + Capabilities: channel.ChannelCapabilities{ + Text: true, + Markdown: true, + Attachments: true, + Media: true, + Reply: true, + Streaming: true, + BlockStreaming: true, + ChatTypes: []string{"private", "group"}, + }, + ConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "botId": {Type: channel.FieldString, Required: true, Title: "Bot ID"}, + "secret": {Type: channel.FieldSecret, Required: true, Title: "Secret"}, + "wsUrl": {Type: channel.FieldString, Title: "WebSocket URL", Example: defaultWSURL}, + "heartbeatSeconds": {Type: channel.FieldNumber, Title: "Heartbeat Seconds"}, + "ackTimeoutSeconds": {Type: channel.FieldNumber, Title: "Ack Timeout Seconds"}, + "writeTimeoutSeconds": {Type: channel.FieldNumber, Title: "Write Timeout Seconds"}, + "readTimeoutSeconds": {Type: channel.FieldNumber, Title: "Read Timeout Seconds"}, + }, + }, + UserConfigSchema: channel.ConfigSchema{ + Version: 1, + Fields: map[string]channel.FieldSchema{ + "chat_id": {Type: channel.FieldString}, + "user_id": {Type: channel.FieldString}, + }, + }, + TargetSpec: channel.TargetSpec{ + Format: "chat_id:xxx | user_id:xxx", + Hints: []channel.TargetHint{ + {Label: "Chat ID", Example: "chat_id:wrk_abc"}, + {Label: "User ID", Example: "user_id:zhangsan"}, + }, + }, + } +} + +func (a *WeComAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { + return normalizeConfig(raw) +} + +func (a *WeComAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { + return normalizeUserConfig(raw) +} + +func (a *WeComAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } + +func (a *WeComAdapter) ResolveTarget(userConfig map[string]any) (string, error) { + return resolveTarget(userConfig) +} + +func (a *WeComAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { + return matchBinding(config, criteria) +} + +func (a *WeComAdapter) BuildUserConfig(identity channel.Identity) map[string]any { + return buildUserConfig(identity) +} + +func (a *WeComAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { + _ = ctx + cfg, err := parseConfig(credentials) + if err != nil { + return nil, "", err + } + externalID := strings.TrimSpace(cfg.BotID) + identity := map[string]any{ + "bot_id": externalID, + "aibot_id": externalID, + } + return identity, externalID, nil +} + +func (a *WeComAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { + parsed, err := parseConfig(cfg.Credentials) + if err != nil { + return nil, err + } + client := a.newWSClient(WSClientOptions{ + URL: parsed.WSURL, + Logger: a.logger, + HeartbeatInterval: time.Duration(secondsOrDefault(parsed.HeartbeatSeconds, 30)) * time.Second, + AckTimeout: time.Duration(secondsOrDefault(parsed.AckTimeoutSeconds, 8)) * time.Second, + WriteTimeout: time.Duration(secondsOrDefault(parsed.WriteTimeoutSeconds, 8)) * time.Second, + ReadTimeout: time.Duration(secondsOrDefault(parsed.ReadTimeoutSeconds, 70)) * time.Second, + ReconnectBaseDelay: 1 * time.Second, + ReconnectMaxDelay: 30 * time.Second, + }) + + key := strings.TrimSpace(parsed.BotID) + a.mu.Lock() + a.clients[key] = client + a.mu.Unlock() + + connCtx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + go func() { + defer close(done) + err := client.Run(connCtx, AuthCredentials{ + BotID: parsed.BotID, + Secret: parsed.Secret, + }, func(frameCtx context.Context, frame WSFrame) error { + return a.handleFrame(frameCtx, cfg, frame, handler) + }) + if err != nil && connCtx.Err() == nil { + a.logger.Error("wecom websocket stopped", + slog.String("config_id", cfg.ID), + slog.Any("error", err), + ) + } + }() + + stop := func(context.Context) error { + cancel() + _ = client.Close() + <-done + a.mu.Lock() + if current, ok := a.clients[key]; ok && current == client { + delete(a.clients, key) + } + a.mu.Unlock() + return nil + } + return channel.NewConnection(cfg, stop), nil +} + +func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { + targetKind, targetID, ok := parseTarget(msg.Target) + if !ok { + return fmt.Errorf("wecom target is required") + } + parsed, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + client := a.getClient(parsed.BotID) + if client == nil { + return fmt.Errorf("wecom connection is not active") + } + if msg.Message.IsEmpty() { + return fmt.Errorf("message is required") + } + var ( + payload any + cmd string + reqID string + buildErr error + ) + if ctxMeta, ok := a.lookupCallbackContext(msg.Message.Reply); ok { + payload, cmd, reqID, buildErr = buildRespondPayload(msg.Message, ctxMeta.ReqID) + } else { + _ = targetKind + payload, cmd, reqID, buildErr = buildSendPayload(msg.Message, targetID) + } + if buildErr != nil { + return buildErr + } + ack, err := client.Reply(ctx, reqID, cmd, payload) + if err != nil { + return err + } + if ack.ErrCode != 0 { + return fmt.Errorf("wecom send failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode) + } + return nil +} + +func (a *WeComAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + target = strings.TrimSpace(target) + if target == "" { + return nil, fmt.Errorf("wecom target is required") + } + reply := opts.Reply + if reply == nil && strings.TrimSpace(opts.SourceMessageID) != "" { + reply = &channel.ReplyRef{ + Target: target, + MessageID: strings.TrimSpace(opts.SourceMessageID), + } + } + return &wecomOutboundStream{ + adapter: a, + cfg: cfg, + target: target, + reply: reply, + }, nil +} + +type wecomOutboundStream struct { + adapter *WeComAdapter + cfg channel.ChannelConfig + target string + reply *channel.ReplyRef + + mu sync.Mutex + closed atomic.Bool + finalSent atomic.Bool + textBuilder strings.Builder + attachments []channel.Attachment + final *channel.Message + streamID string + lastPreview string +} + +func (s *wecomOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { + if s.adapter == nil { + return errors.New("wecom stream not configured") + } + if s.closed.Load() { + return errors.New("wecom stream is closed") + } + if s.finalSent.Load() { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + switch event.Type { + case channel.StreamEventStatus, + channel.StreamEventPhaseStart, + channel.StreamEventPhaseEnd, + channel.StreamEventToolCallStart, + channel.StreamEventToolCallEnd, + channel.StreamEventAgentStart, + channel.StreamEventAgentEnd, + channel.StreamEventProcessingStarted, + channel.StreamEventProcessingCompleted, + channel.StreamEventProcessingFailed: + return nil + case channel.StreamEventDelta: + if strings.TrimSpace(event.Delta) == "" || event.Phase == channel.StreamPhaseReasoning { + return nil + } + s.mu.Lock() + s.textBuilder.WriteString(event.Delta) + s.mu.Unlock() + return s.pushPreview(ctx) + case channel.StreamEventAttachment: + if len(event.Attachments) == 0 { + return nil + } + s.mu.Lock() + s.attachments = append(s.attachments, event.Attachments...) + s.mu.Unlock() + return nil + case channel.StreamEventFinal: + if event.Final == nil { + return nil + } + s.mu.Lock() + final := event.Final.Message + s.final = &final + s.mu.Unlock() + return s.flush(ctx) + case channel.StreamEventError: + text := strings.TrimSpace(event.Error) + if text == "" { + return nil + } + s.mu.Lock() + s.final = &channel.Message{Format: channel.MessageFormatPlain, Text: "Error: " + text} + s.mu.Unlock() + return s.flush(ctx) + } + return nil +} + +func (s *wecomOutboundStream) Close(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + if s.finalSent.Load() { + return nil + } + return s.flush(ctx) +} + +func (s *wecomOutboundStream) flush(ctx context.Context) error { + if s.finalSent.Load() { + return nil + } + msg, streamID := s.snapshotMessage(true) + if msg.IsEmpty() { + return nil + } + if ctxMeta, ok := s.adapter.lookupCallbackContext(msg.Reply); ok { + if err := s.adapter.sendRespondStream(ctx, s.cfg, msg, ctxMeta.ReqID, streamID, true); err != nil { + return err + } + s.finalSent.Store(true) + return nil + } + if err := s.adapter.Send(ctx, s.cfg, channel.OutboundMessage{ + Target: s.target, + Message: msg, + }); err != nil { + return err + } + s.finalSent.Store(true) + return nil +} + +func (s *wecomOutboundStream) pushPreview(ctx context.Context) error { + if s.finalSent.Load() { + return nil + } + msg, streamID := s.snapshotMessage(false) + text := strings.TrimSpace(msg.PlainText()) + if text == "" { + return nil + } + s.mu.Lock() + if s.lastPreview == text { + s.mu.Unlock() + return nil + } + s.mu.Unlock() + if ctxMeta, ok := s.adapter.lookupCallbackContext(msg.Reply); ok { + if err := s.adapter.sendRespondStream(ctx, s.cfg, msg, ctxMeta.ReqID, streamID, false); err != nil { + return err + } + s.mu.Lock() + s.lastPreview = text + s.mu.Unlock() + } + return nil +} + +func (s *wecomOutboundStream) snapshotMessage(includeAttachments bool) (channel.Message, string) { + s.mu.Lock() + defer s.mu.Unlock() + msg := channel.Message{} + if s.final != nil { + msg = *s.final + } + if strings.TrimSpace(msg.Text) == "" { + msg.Text = strings.TrimSpace(s.textBuilder.String()) + } + if includeAttachments && len(msg.Attachments) == 0 && len(s.attachments) > 0 { + msg.Attachments = append(msg.Attachments, s.attachments...) + } + if msg.Reply == nil && s.reply != nil { + msg.Reply = s.reply + } + if s.streamID == "" { + s.streamID = NewReqID("stream") + } + return msg, s.streamID +} + +func (a *WeComAdapter) sendRespondStream(ctx context.Context, cfg channel.ChannelConfig, msg channel.Message, reqID string, streamID string, finish bool) error { + parsed, err := parseConfig(cfg.Credentials) + if err != nil { + return err + } + client := a.getClient(parsed.BotID) + if client == nil { + return fmt.Errorf("wecom connection is not active") + } + payload, cmd, ackReqID, err := buildRespondPayloadWithStream(msg, reqID, streamID, finish) + if err != nil { + return err + } + ack, err := client.Reply(ctx, ackReqID, cmd, payload) + if err != nil { + return err + } + if ack.ErrCode != 0 { + return fmt.Errorf("wecom send failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode) + } + return nil +} + +func secondsOrDefault(value int, fallback int) int { + if value > 0 { + return value + } + return fallback +} diff --git a/internal/channel/adapters/wecom/wecom_test.go b/internal/channel/adapters/wecom/wecom_test.go new file mode 100644 index 00000000..e58bfc1f --- /dev/null +++ b/internal/channel/adapters/wecom/wecom_test.go @@ -0,0 +1,52 @@ +package wecom + +import ( + "context" + "testing" + + "github.com/memohai/memoh/internal/channel" +) + +func TestDiscoverSelf(t *testing.T) { + adapter := NewWeComAdapter(nil) + identity, externalID, err := adapter.DiscoverSelf(context.Background(), map[string]any{ + "botId": "bot_123", + "secret": "sec", + }) + if err != nil { + t.Fatalf("DiscoverSelf error = %v", err) + } + if externalID != "bot_123" { + t.Fatalf("unexpected external id: %q", externalID) + } + if identity["bot_id"] != "bot_123" { + t.Fatalf("unexpected bot_id: %v", identity["bot_id"]) + } + if identity["aibot_id"] != "bot_123" { + t.Fatalf("unexpected aibot_id: %v", identity["aibot_id"]) + } + if _, ok := identity["name"]; ok { + t.Fatalf("unexpected name field: %v", identity["name"]) + } + if _, ok := identity["display_name"]; ok { + t.Fatalf("unexpected display_name field: %v", identity["display_name"]) + } +} + +func TestOpenStream_FallbackReplyFromSourceMessageID(t *testing.T) { + adapter := NewWeComAdapter(nil) + stream, err := adapter.OpenStream(context.Background(), channel.ChannelConfig{}, "chat_id:chat_1", channel.StreamOptions{ + SourceMessageID: "msg_1", + }) + if err != nil { + t.Fatalf("OpenStream error = %v", err) + } + ws, ok := stream.(*wecomOutboundStream) + if !ok { + t.Fatalf("unexpected stream type: %T", stream) + } + if ws.reply == nil || ws.reply.MessageID != "msg_1" || ws.reply.Target != "chat_id:chat_1" { + t.Fatalf("unexpected reply fallback: %+v", ws.reply) + } +} + diff --git a/internal/channel/adapters/wecom/ws_client.go b/internal/channel/adapters/wecom/ws_client.go new file mode 100644 index 00000000..04edcf42 --- /dev/null +++ b/internal/channel/adapters/wecom/ws_client.go @@ -0,0 +1,396 @@ +package wecom + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type WSClientOptions struct { + URL string + Dialer *websocket.Dialer + Logger *slog.Logger + AckTimeout time.Duration + WriteTimeout time.Duration + ReadTimeout time.Duration + HeartbeatInterval time.Duration + ReconnectBaseDelay time.Duration + ReconnectMaxDelay time.Duration + MaxReconnectAttempts int +} + +type WSClient struct { + opts WSClientOptions + logger *slog.Logger + + writeMu sync.Mutex + waitMu sync.Mutex + connMu sync.RWMutex + + conn *websocket.Conn + waiters map[string]chan wsAck + closed bool +} + +type wsAck struct { + frame WSFrame + err error +} + +func NewWSClient(opts WSClientOptions) *WSClient { + if strings.TrimSpace(opts.URL) == "" { + opts.URL = defaultWSURL + } + if opts.Logger == nil { + opts.Logger = slog.Default() + } + if opts.AckTimeout <= 0 { + opts.AckTimeout = 8 * time.Second + } + if opts.WriteTimeout <= 0 { + opts.WriteTimeout = 8 * time.Second + } + if opts.ReadTimeout <= 0 { + opts.ReadTimeout = 70 * time.Second + } + if opts.HeartbeatInterval <= 0 { + opts.HeartbeatInterval = 30 * time.Second + } + if opts.ReconnectBaseDelay <= 0 { + opts.ReconnectBaseDelay = 1 * time.Second + } + if opts.ReconnectMaxDelay <= 0 { + opts.ReconnectMaxDelay = 30 * time.Second + } + return &WSClient{ + opts: opts, + logger: opts.Logger.With(slog.String("component", "wecom_ws_client")), + waiters: make(map[string]chan wsAck), + } +} + +func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error { + if err := auth.Validate(); err != nil { + return err + } + attempt := 0 + for { + err := c.runSession(ctx, auth, onFrame) + if ctx.Err() != nil { + return ctx.Err() + } + if c.isClosed() { + return nil + } + if c.opts.MaxReconnectAttempts >= 0 && attempt >= c.opts.MaxReconnectAttempts { + if err == nil { + return fmt.Errorf("wecom websocket reconnect attempts exceeded") + } + return err + } + delay := c.backoff(attempt) + attempt++ + c.logger.Warn("wecom websocket session ended; reconnecting", + slog.Int("attempt", attempt), + slog.Duration("delay", delay), + slog.Any("error", err), + ) + timer := time.NewTimer(delay) + select { + case <-ctx.Done(): + timer.Stop() + return ctx.Err() + case <-timer.C: + } + } +} + +func (c *WSClient) runSession(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error { + conn, _, err := c.dial(ctx) + if err != nil { + return err + } + c.setConn(conn) + sessionCtx, cancel := context.WithCancel(ctx) + defer func() { + cancel() + c.clearConn() + c.failAllWaiters(fmt.Errorf("wecom websocket disconnected")) + }() + + readErrCh := make(chan error, 1) + go c.readLoop(sessionCtx, onFrame, readErrCh) + + if err := c.authenticate(sessionCtx, auth); err != nil { + _ = conn.Close() + return err + } + c.logger.Info("wecom websocket authenticated") + + go c.heartbeatLoop(sessionCtx) + + select { + case <-sessionCtx.Done(): + _ = conn.Close() + return sessionCtx.Err() + case err := <-readErrCh: + _ = conn.Close() + return err + } +} + +func (c *WSClient) dial(ctx context.Context) (*websocket.Conn, *http.Response, error) { + dialer := c.opts.Dialer + if dialer == nil { + dialer = websocket.DefaultDialer + } + return dialer.DialContext(ctx, c.opts.URL, nil) +} + +func (c *WSClient) authenticate(ctx context.Context, auth AuthCredentials) error { + frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), SubscribeBody{ + BotID: strings.TrimSpace(auth.BotID), + Secret: strings.TrimSpace(auth.Secret), + }) + if err != nil { + return err + } + ack, err := c.SendWithAck(ctx, frame) + if err != nil { + return err + } + if ack.ErrCode != 0 { + return fmt.Errorf("wecom subscribe failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode) + } + return nil +} + +func (c *WSClient) heartbeatLoop(ctx context.Context) { + ticker := time.NewTicker(c.opts.HeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + frame, err := BuildFrame(WSCmdHeartbeat, NewReqID(WSCmdHeartbeat), nil) + if err != nil { + c.logger.Error("build heartbeat frame failed", slog.Any("error", err)) + continue + } + err = c.Send(ctx, frame) + if err != nil { + c.logger.Warn("wecom websocket heartbeat failed", slog.Any("error", err)) + if conn := c.getConn(); conn != nil { + _ = conn.Close() + } + return + } + } + } +} + +func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, WSFrame) error, errCh chan<- error) { + conn := c.getConn() + if conn == nil { + errCh <- fmt.Errorf("wecom websocket connection not ready") + return + } + for { + if c.opts.ReadTimeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(c.opts.ReadTimeout)) + } + _, payload, err := conn.ReadMessage() + if err != nil { + errCh <- err + return + } + var frame WSFrame + if err := json.Unmarshal(payload, &frame); err != nil { + c.logger.Warn("decode websocket frame failed", slog.Any("error", err)) + continue + } + if c.dispatchAck(frame) { + continue + } + if onFrame == nil { + continue + } + if err := onFrame(ctx, frame); err != nil { + c.logger.Warn("wecom onFrame callback returned error", slog.Any("error", err)) + } + } +} + +func (c *WSClient) Send(ctx context.Context, frame WSFrame) error { + if strings.TrimSpace(frame.Headers.ReqID) == "" { + return fmt.Errorf("req_id is required") + } + conn := c.getConn() + if conn == nil { + return fmt.Errorf("wecom websocket is not connected") + } + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.opts.WriteTimeout > 0 { + _ = conn.SetWriteDeadline(time.Now().Add(c.opts.WriteTimeout)) + } + if err := conn.WriteJSON(frame); err != nil { + return err + } + return nil +} + +func (c *WSClient) SendWithAck(ctx context.Context, frame WSFrame) (WSFrame, error) { + reqID := strings.TrimSpace(frame.Headers.ReqID) + if reqID == "" { + return WSFrame{}, fmt.Errorf("req_id is required") + } + wait := make(chan wsAck, 1) + c.waitMu.Lock() + c.waiters[reqID] = wait + c.waitMu.Unlock() + defer func() { + c.waitMu.Lock() + delete(c.waiters, reqID) + c.waitMu.Unlock() + }() + if err := c.Send(ctx, frame); err != nil { + return WSFrame{}, err + } + timeout := c.opts.AckTimeout + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 && remaining < timeout { + timeout = remaining + } + } + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case <-ctx.Done(): + return WSFrame{}, ctx.Err() + case <-timer.C: + return WSFrame{}, fmt.Errorf("wait websocket ack timeout for req_id=%s", reqID) + case ack := <-wait: + if ack.err != nil { + return WSFrame{}, ack.err + } + return ack.frame, nil + } +} + +func (c *WSClient) Close() error { + c.connMu.Lock() + c.closed = true + conn := c.conn + c.conn = nil + c.connMu.Unlock() + c.failAllWaiters(fmt.Errorf("wecom websocket client closed")) + if conn == nil { + return nil + } + return conn.Close() +} + +func (c *WSClient) Reply(ctx context.Context, reqID string, cmd string, body any) (WSFrame, error) { + frame, err := BuildFrame(cmd, reqID, body) + if err != nil { + return WSFrame{}, err + } + // WeCom callback reply commands are triggered by inbound callback req_id and may + // not always return an explicit ACK frame in production. Waiting for ACK here can + // cause false timeouts even when the platform accepts the reply. + if isRespondCommand(cmd) { + if err := c.Send(ctx, frame); err != nil { + return WSFrame{}, err + } + return WSFrame{}, nil + } + return c.SendWithAck(ctx, frame) +} + +func isRespondCommand(cmd string) bool { + switch strings.TrimSpace(cmd) { + case WSCmdRespond, WSCmdRespondWelcome, WSCmdRespondUpdate: + return true + default: + return false + } +} + +func (c *WSClient) dispatchAck(frame WSFrame) bool { + reqID := strings.TrimSpace(frame.Headers.ReqID) + if reqID == "" { + return false + } + c.waitMu.Lock() + wait, ok := c.waiters[reqID] + c.waitMu.Unlock() + if !ok { + return false + } + select { + case wait <- wsAck{frame: frame}: + default: + } + return true +} + +func (c *WSClient) failAllWaiters(cause error) { + c.waitMu.Lock() + defer c.waitMu.Unlock() + for id, wait := range c.waiters { + delete(c.waiters, id) + select { + case wait <- wsAck{err: cause}: + default: + } + } +} + +func (c *WSClient) backoff(attempt int) time.Duration { + if attempt < 0 { + attempt = 0 + } + delay := c.opts.ReconnectBaseDelay << attempt + if delay > c.opts.ReconnectMaxDelay { + return c.opts.ReconnectMaxDelay + } + return delay +} + +func (c *WSClient) setConn(conn *websocket.Conn) { + c.connMu.Lock() + defer c.connMu.Unlock() + c.conn = conn +} + +func (c *WSClient) clearConn() { + c.connMu.Lock() + conn := c.conn + c.conn = nil + c.connMu.Unlock() + if conn != nil { + _ = conn.Close() + } +} + +func (c *WSClient) getConn() *websocket.Conn { + c.connMu.RLock() + defer c.connMu.RUnlock() + return c.conn +} + +func (c *WSClient) isClosed() bool { + c.connMu.RLock() + defer c.connMu.RUnlock() + return c.closed +} diff --git a/internal/channel/adapters/wecom/ws_client_test.go b/internal/channel/adapters/wecom/ws_client_test.go new file mode 100644 index 00000000..7100dc41 --- /dev/null +++ b/internal/channel/adapters/wecom/ws_client_test.go @@ -0,0 +1,178 @@ +package wecom + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) { + t.Parallel() + + var connCount atomic.Int32 + callbackSent := make(chan struct{}, 1) + upgrader := websocket.Upgrader{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + n := connCount.Add(1) + + var subscribeFrame WSFrame + if err := conn.ReadJSON(&subscribeFrame); err != nil { + return + } + _ = conn.WriteJSON(WSFrame{ + Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID}, + ErrCode: 0, + }) + + if n == 1 { + return + } + body, _ := json.Marshal(MessageCallbackBody{ + MsgID: "m1", + ChatID: "chat_1", + ChatType: "group", + CreateTime: time.Now().UnixMilli(), + From: CallbackFrom{UserID: "u1"}, + MsgType: "text", + Text: &MessageText{Content: "hello"}, + }) + _ = conn.WriteJSON(WSFrame{ + Cmd: WSCmdMsgCallback, + Headers: WSHeaders{ReqID: "cb_req"}, + Body: body, + }) + select { + case callbackSent <- struct{}{}: + default: + } + <-time.After(100 * time.Millisecond) + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + client := NewWSClient(WSClientOptions{ + URL: wsURL, + AckTimeout: 200 * time.Millisecond, + HeartbeatInterval: 10 * time.Second, + ReconnectBaseDelay: 10 * time.Millisecond, + ReconnectMaxDelay: 20 * time.Millisecond, + MaxReconnectAttempts: 5, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + runErrCh := make(chan error, 1) + go func() { + runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, func(context.Context, WSFrame) error { + cancel() + return nil + }) + }() + + select { + case <-callbackSent: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting callback on reconnected session") + } + select { + case err := <-runErrCh: + if err == nil || err != context.Canceled { + t.Fatalf("unexpected run error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting run return") + } + if connCount.Load() < 2 { + t.Fatalf("expected reconnect attempts >= 2, got %d", connCount.Load()) + } +} + +func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) { + t.Parallel() + + var connCount atomic.Int32 + heartbeatSeen := make(chan struct{}, 1) + upgrader := websocket.Upgrader{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + n := connCount.Add(1) + + var subscribeFrame WSFrame + if err := conn.ReadJSON(&subscribeFrame); err != nil { + return + } + _ = conn.WriteJSON(WSFrame{ + Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID}, + ErrCode: 0, + }) + if n > 1 { + return + } + var heartbeatFrame WSFrame + if err := conn.ReadJSON(&heartbeatFrame); err != nil { + return + } + if heartbeatFrame.Cmd == WSCmdHeartbeat { + select { + case heartbeatSeen <- struct{}{}: + default: + } + } + <-time.After(200 * time.Millisecond) + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + client := NewWSClient(WSClientOptions{ + URL: wsURL, + HeartbeatInterval: 20 * time.Millisecond, + ReconnectBaseDelay: 10 * time.Millisecond, + ReconnectMaxDelay: 20 * time.Millisecond, + MaxReconnectAttempts: 5, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + runErrCh := make(chan error, 1) + go func() { + runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, nil) + }() + + select { + case <-heartbeatSeen: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting heartbeat frame") + } + <-time.After(250 * time.Millisecond) + cancel() + select { + case err := <-runErrCh: + if err == nil || err != context.Canceled { + t.Fatalf("unexpected run error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting run return") + } + if connCount.Load() < 1 { + t.Fatalf("expected at least one session, got %d", connCount.Load()) + } +} +