fix(wecom): align adapter with channel stream behavior

Migrate the imported WeCom adapter to current channel interfaces and stabilize stream delivery by preventing heartbeat/reply ACK timeout regressions and post-final overwrite updates.
This commit is contained in:
BBQ
2026-03-10 17:43:47 +08:00
committed by 晨苒
parent ef7ed961a9
commit bc47655309
21 changed files with 3143 additions and 0 deletions
@@ -0,0 +1,157 @@
package wecom
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/memohai/memoh/internal/channel"
)
func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) {
t.Parallel()
upgrader := websocket.Upgrader{}
receivedRespond := make(chan WSFrame, 1)
serverErr := make(chan error, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
select {
case serverErr <- err:
default:
}
return
}
defer conn.Close()
var subscribeFrame WSFrame
if err := conn.ReadJSON(&subscribeFrame); err != nil {
select {
case serverErr <- err:
default:
}
return
}
if err := conn.WriteJSON(WSFrame{
Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID},
ErrCode: 0,
}); err != nil {
select {
case serverErr <- err:
default:
}
return
}
body, _ := json.Marshal(MessageCallbackBody{
MsgID: "msg_1",
ChatID: "chat_1",
ChatType: "group",
CreateTime: time.Now().UnixMilli(),
From: CallbackFrom{UserID: "u1"},
MsgType: "text",
ResponseURL: "https://example.com/resp",
Text: &MessageText{Content: "hello"},
})
if err := conn.WriteJSON(WSFrame{
Cmd: WSCmdMsgCallback,
Headers: WSHeaders{ReqID: "callback_req_id"},
Body: body,
}); err != nil {
select {
case serverErr <- err:
default:
}
return
}
var respondFrame WSFrame
if err := conn.ReadJSON(&respondFrame); err != nil {
select {
case serverErr <- err:
default:
}
return
}
select {
case receivedRespond <- respondFrame:
default:
}
_ = conn.WriteJSON(WSFrame{
Headers: WSHeaders{ReqID: respondFrame.Headers.ReqID},
ErrCode: 0,
})
}))
defer server.Close()
adapter := NewWeComAdapter(nil)
cfg := channel.ChannelConfig{
Credentials: map[string]any{
"botId": "bot",
"secret": "sec",
"wsUrl": "ws" + strings.TrimPrefix(server.URL, "http"),
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
inboundCh := make(chan channel.InboundMessage, 1)
conn, err := adapter.Connect(ctx, cfg, func(_ context.Context, _ channel.ChannelConfig, msg channel.InboundMessage) error {
select {
case inboundCh <- msg:
default:
}
return nil
})
if err != nil {
t.Fatalf("connect error: %v", err)
}
defer conn.Stop(context.Background())
select {
case inbound := <-inboundCh:
if inbound.Message.ID != "msg_1" {
t.Fatalf("unexpected inbound message id: %s", inbound.Message.ID)
}
case err := <-serverErr:
t.Fatalf("server error: %v", err)
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting inbound callback")
}
err = adapter.Send(context.Background(), cfg, channel.OutboundMessage{
Target: "chat_id:chat_1",
Message: channel.Message{
Format: channel.MessageFormatPlain,
Text: "reply content",
Reply: &channel.ReplyRef{
MessageID: "msg_1",
},
},
})
if err != nil {
t.Fatalf("send error: %v", err)
}
select {
case frame := <-receivedRespond:
if frame.Cmd != WSCmdRespond {
t.Fatalf("expected cmd=%s got=%s", WSCmdRespond, frame.Cmd)
}
if frame.Headers.ReqID != "callback_req_id" {
t.Fatalf("expected req_id callback_req_id got=%s", frame.Headers.ReqID)
}
case err := <-serverErr:
t.Fatalf("server error: %v", err)
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting respond frame")
}
}
@@ -0,0 +1,77 @@
package wecom
import (
"strings"
"sync"
"time"
)
type callbackContext struct {
ReqID string
ResponseURL string
ChatID string
UserID string
CreatedAt time.Time
}
type callbackContextCache struct {
mu sync.RWMutex
items map[string]callbackContext
ttl time.Duration
}
func newCallbackContextCache(ttl time.Duration) *callbackContextCache {
if ttl <= 0 {
ttl = 24 * time.Hour
}
return &callbackContextCache{
items: make(map[string]callbackContext),
ttl: ttl,
}
}
func (c *callbackContextCache) Put(messageID string, ctx callbackContext) {
key := strings.TrimSpace(messageID)
if key == "" {
return
}
if ctx.CreatedAt.IsZero() {
ctx.CreatedAt = time.Now().UTC()
}
c.mu.Lock()
c.items[key] = ctx
c.gcLocked()
c.mu.Unlock()
}
func (c *callbackContextCache) Get(messageID string) (callbackContext, bool) {
key := strings.TrimSpace(messageID)
if key == "" {
return callbackContext{}, false
}
c.mu.RLock()
item, ok := c.items[key]
c.mu.RUnlock()
if !ok {
return callbackContext{}, false
}
if time.Since(item.CreatedAt) > c.ttl {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return callbackContext{}, false
}
return item, true
}
func (c *callbackContextCache) gcLocked() {
if len(c.items) < 512 {
return
}
now := time.Now().UTC()
for key, item := range c.items {
if now.Sub(item.CreatedAt) > c.ttl {
delete(c.items, key)
}
}
}
@@ -0,0 +1,29 @@
package wecom
import (
"testing"
"time"
)
func TestCallbackContextCache_PutGet(t *testing.T) {
cache := newCallbackContextCache(1 * time.Hour)
cache.Put("m1", callbackContext{ReqID: "r1"})
got, ok := cache.Get("m1")
if !ok {
t.Fatal("expected cache hit")
}
if got.ReqID != "r1" {
t.Fatalf("unexpected req id: %q", got.ReqID)
}
}
func TestCallbackContextCache_Expires(t *testing.T) {
cache := newCallbackContextCache(1 * time.Second)
cache.Put("m1", callbackContext{
ReqID: "r1",
CreatedAt: time.Now().Add(-2 * time.Second),
})
if _, ok := cache.Get("m1"); ok {
t.Fatal("expected cache miss due to expiry")
}
}
+210
View File
@@ -0,0 +1,210 @@
package wecom
import (
"fmt"
"strconv"
"strings"
"github.com/memohai/memoh/internal/channel"
)
type Config struct {
BotID string
Secret string
WSURL string
HeartbeatSeconds int
AckTimeoutSeconds int
WriteTimeoutSeconds int
ReadTimeoutSeconds int
}
type UserConfig struct {
ChatID string
UserID string
}
func normalizeConfig(raw map[string]any) (map[string]any, error) {
cfg, err := parseConfig(raw)
if err != nil {
return nil, err
}
out := map[string]any{
"botId": cfg.BotID,
"secret": cfg.Secret,
}
if cfg.WSURL != "" {
out["wsUrl"] = cfg.WSURL
}
if cfg.HeartbeatSeconds > 0 {
out["heartbeatSeconds"] = cfg.HeartbeatSeconds
}
if cfg.AckTimeoutSeconds > 0 {
out["ackTimeoutSeconds"] = cfg.AckTimeoutSeconds
}
if cfg.WriteTimeoutSeconds > 0 {
out["writeTimeoutSeconds"] = cfg.WriteTimeoutSeconds
}
if cfg.ReadTimeoutSeconds > 0 {
out["readTimeoutSeconds"] = cfg.ReadTimeoutSeconds
}
return out, nil
}
func normalizeUserConfig(raw map[string]any) (map[string]any, error) {
cfg, err := parseUserConfig(raw)
if err != nil {
return nil, err
}
out := map[string]any{}
if cfg.ChatID != "" {
out["chat_id"] = cfg.ChatID
}
if cfg.UserID != "" {
out["user_id"] = cfg.UserID
}
return out, nil
}
func parseConfig(raw map[string]any) (Config, error) {
cfg := Config{
BotID: strings.TrimSpace(channel.ReadString(raw, "botId", "bot_id")),
Secret: strings.TrimSpace(channel.ReadString(raw, "secret")),
WSURL: strings.TrimSpace(channel.ReadString(raw, "wsUrl", "ws_url")),
}
if value, ok := readInt(raw, "heartbeatSeconds", "heartbeat_seconds"); ok {
cfg.HeartbeatSeconds = value
}
if value, ok := readInt(raw, "ackTimeoutSeconds", "ack_timeout_seconds"); ok {
cfg.AckTimeoutSeconds = value
}
if value, ok := readInt(raw, "writeTimeoutSeconds", "write_timeout_seconds"); ok {
cfg.WriteTimeoutSeconds = value
}
if value, ok := readInt(raw, "readTimeoutSeconds", "read_timeout_seconds"); ok {
cfg.ReadTimeoutSeconds = value
}
if cfg.BotID == "" || cfg.Secret == "" {
return Config{}, fmt.Errorf("wecom botId and secret are required")
}
return cfg, nil
}
func parseUserConfig(raw map[string]any) (UserConfig, error) {
cfg := UserConfig{
ChatID: strings.TrimSpace(channel.ReadString(raw, "chatId", "chat_id")),
UserID: strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")),
}
if cfg.ChatID == "" && cfg.UserID == "" {
return UserConfig{}, fmt.Errorf("wecom user config requires chat_id or user_id")
}
return cfg, nil
}
func resolveTarget(raw map[string]any) (string, error) {
cfg, err := parseUserConfig(raw)
if err != nil {
return "", err
}
if cfg.ChatID != "" {
return "chat_id:" + cfg.ChatID, nil
}
return "user_id:" + cfg.UserID, nil
}
func normalizeTarget(raw string) string {
kind, id, ok := parseTarget(raw)
if !ok {
return ""
}
return kind + ":" + id
}
func parseTarget(raw string) (kind string, id string, ok bool) {
v := strings.TrimSpace(raw)
if v == "" {
return "", "", false
}
v = strings.TrimPrefix(v, "wecom:")
v = strings.TrimPrefix(v, "workwx:")
v = strings.TrimSpace(v)
lv := strings.ToLower(v)
switch {
case strings.HasPrefix(lv, "chat_id:"):
id = strings.TrimSpace(v[len("chat_id:"):])
return "chat_id", id, id != ""
case strings.HasPrefix(lv, "chat:"):
id = strings.TrimSpace(v[len("chat:"):])
return "chat_id", id, id != ""
case strings.HasPrefix(lv, "group:"):
id = strings.TrimSpace(v[len("group:"):])
return "chat_id", id, id != ""
case strings.HasPrefix(lv, "user_id:"):
id = strings.TrimSpace(v[len("user_id:"):])
return "user_id", id, id != ""
case strings.HasPrefix(lv, "user:"):
id = strings.TrimSpace(v[len("user:"):])
return "user_id", id, id != ""
default:
return "chat_id", v, true
}
}
func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool {
cfg, err := parseUserConfig(raw)
if err != nil {
return false
}
if value := strings.TrimSpace(criteria.Attribute("chat_id")); value != "" && value == cfg.ChatID {
return true
}
if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID {
return true
}
if criteria.SubjectID != "" && (criteria.SubjectID == cfg.ChatID || criteria.SubjectID == cfg.UserID) {
return true
}
return false
}
func buildUserConfig(identity channel.Identity) map[string]any {
out := map[string]any{}
if v := strings.TrimSpace(identity.Attribute("chat_id")); v != "" {
out["chat_id"] = v
}
if v := strings.TrimSpace(identity.Attribute("user_id")); v != "" {
out["user_id"] = v
}
return out
}
func readInt(raw map[string]any, keys ...string) (int, bool) {
for _, key := range keys {
value, ok := raw[key]
if !ok {
continue
}
switch v := value.(type) {
case int:
return v, true
case int32:
return int(v), true
case int64:
return int(v), true
case float64:
return int(v), true
case float32:
return int(v), true
case string:
trimmed := strings.TrimSpace(v)
if trimmed == "" {
continue
}
parsed, err := strconv.Atoi(trimmed)
if err != nil {
continue
}
return parsed, true
}
}
return 0, false
}
@@ -0,0 +1,27 @@
package wecom
import "testing"
func TestParseConfig(t *testing.T) {
cfg, err := parseConfig(map[string]any{
"botId": "bot-1",
"secret": "sec-1",
})
if err != nil {
t.Fatalf("parseConfig error = %v", err)
}
if cfg.BotID != "bot-1" || cfg.Secret != "sec-1" {
t.Fatalf("unexpected config: %+v", cfg)
}
}
func TestParseTarget(t *testing.T) {
kind, id, ok := parseTarget("chat_id:abc")
if !ok || kind != "chat_id" || id != "abc" {
t.Fatalf("unexpected target parse result: ok=%v kind=%q id=%q", ok, kind, id)
}
kind, id, ok = parseTarget("user_id:zhangsan")
if !ok || kind != "user_id" || id != "zhangsan" {
t.Fatalf("unexpected target parse result: ok=%v kind=%q id=%q", ok, kind, id)
}
}
+52
View File
@@ -0,0 +1,52 @@
package wecom
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
)
func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error) {
if len(ciphertext) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
}
key, err := base64.StdEncoding.DecodeString(aesKeyBase64)
if err != nil {
return nil, fmt.Errorf("decode aes key failed: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("invalid aes key length: %d", len(key))
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("invalid ciphertext block size")
}
iv := key[:aes.BlockSize]
out := make([]byte, len(ciphertext))
cipher.NewCBCDecrypter(block, iv).CryptBlocks(out, ciphertext)
plain, err := pkcs7Unpad(out, 32)
if err != nil {
return nil, err
}
return plain, nil
}
func pkcs7Unpad(data []byte, maxPad int) ([]byte, error) {
if len(data) == 0 {
return nil, fmt.Errorf("pkcs7 payload is empty")
}
pad := int(data[len(data)-1])
if pad <= 0 || pad > maxPad || pad > len(data) {
return nil, fmt.Errorf("invalid pkcs7 padding length: %d", pad)
}
padding := bytes.Repeat([]byte{byte(pad)}, pad)
if !bytes.Equal(data[len(data)-pad:], padding) {
return nil, fmt.Errorf("invalid pkcs7 padding bytes")
}
return data[:len(data)-pad], nil
}
@@ -0,0 +1,50 @@
package wecom
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"testing"
)
func TestDecryptFileAES256CBC(t *testing.T) {
key := []byte("0123456789abcdef0123456789abcdef")
plain := []byte("hello-wecom-aibot")
ciphertext := encryptPKCS7To32(t, key, plain)
out, err := DecryptFileAES256CBC(ciphertext, base64.StdEncoding.EncodeToString(key))
if err != nil {
t.Fatalf("DecryptFileAES256CBC error = %v", err)
}
if !bytes.Equal(out, plain) {
t.Fatalf("plaintext mismatch: got=%q want=%q", string(out), string(plain))
}
}
func encryptPKCS7To32(t *testing.T, key []byte, plain []byte) []byte {
t.Helper()
padded := pkcs7PadTo32(plain)
block, err := aes.NewCipher(key)
if err != nil {
t.Fatal(err)
}
iv := key[:aes.BlockSize]
out := make([]byte, len(padded))
cipher.NewCBCEncrypter(block, iv).CryptBlocks(out, padded)
return out
}
func pkcs7PadTo32(data []byte) []byte {
const blockSize = 32
pad := blockSize - (len(data) % blockSize)
if pad == 0 {
pad = blockSize
}
out := make([]byte, len(data)+pad)
copy(out, data)
for i := len(data); i < len(out); i++ {
out[i] = byte(pad)
}
return out
}
@@ -0,0 +1,167 @@
package wecom
import (
"context"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"net/url"
"path"
"strings"
"time"
)
type HTTPClientOptions struct {
Logger *slog.Logger
Client *http.Client
Transport *http.Transport
Timeout time.Duration
MaxIdleConns int
MaxIdleConnsPerHost int
IdleConnTimeout time.Duration
TLSHandshakeTimeout time.Duration
ResponseHeaderTimeout time.Duration
}
type HTTPClient struct {
client *http.Client
logger *slog.Logger
}
type DownloadedFile struct {
Data []byte
FileName string
ContentType string
}
func NewHTTPClient(opts HTTPClientOptions) *HTTPClient {
if opts.Logger == nil {
opts.Logger = slog.Default()
}
if opts.Timeout <= 0 {
opts.Timeout = 20 * time.Second
}
if opts.IdleConnTimeout <= 0 {
opts.IdleConnTimeout = 90 * time.Second
}
if opts.TLSHandshakeTimeout <= 0 {
opts.TLSHandshakeTimeout = 10 * time.Second
}
if opts.ResponseHeaderTimeout <= 0 {
opts.ResponseHeaderTimeout = 15 * time.Second
}
if opts.MaxIdleConns <= 0 {
opts.MaxIdleConns = 100
}
if opts.MaxIdleConnsPerHost <= 0 {
opts.MaxIdleConnsPerHost = 10
}
client := opts.Client
if client == nil {
transport := opts.Transport
if transport == nil {
transport = &http.Transport{
MaxIdleConns: opts.MaxIdleConns,
MaxIdleConnsPerHost: opts.MaxIdleConnsPerHost,
IdleConnTimeout: opts.IdleConnTimeout,
TLSHandshakeTimeout: opts.TLSHandshakeTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
}
client = &http.Client{
Transport: transport,
Timeout: opts.Timeout,
}
}
return &HTTPClient{
client: client,
logger: opts.Logger.With(slog.String("component", "wecom_http_client")),
}
}
func (c *HTTPClient) DownloadFile(ctx context.Context, rawURL string) (DownloadedFile, error) {
u := strings.TrimSpace(rawURL)
if u == "" {
return DownloadedFile{}, fmt.Errorf("download url is required")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return DownloadedFile{}, err
}
resp, err := c.client.Do(req)
if err != nil {
return DownloadedFile{}, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return DownloadedFile{}, fmt.Errorf("download failed with status %d", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return DownloadedFile{}, err
}
fileName := extractFilename(resp.Header.Get("Content-Disposition"), u)
return DownloadedFile{
Data: data,
FileName: fileName,
ContentType: strings.TrimSpace(resp.Header.Get("Content-Type")),
}, nil
}
func (c *HTTPClient) DownloadAndDecryptFile(ctx context.Context, rawURL string, aesKey string) (DownloadedFile, error) {
file, err := c.DownloadFile(ctx, rawURL)
if err != nil {
return DownloadedFile{}, err
}
plain, err := DecryptFileAES256CBC(file.Data, aesKey)
if err != nil {
return DownloadedFile{}, err
}
file.Data = plain
return file, nil
}
func extractFilename(contentDisposition, rawURL string) string {
cd := strings.TrimSpace(contentDisposition)
if cd != "" {
_, params, err := mime.ParseMediaType(cd)
if err == nil {
if v := strings.TrimSpace(params["filename*"]); v != "" {
if decoded := decodeRFC5987Filename(v); decoded != "" {
return decoded
}
return v
}
if v := strings.TrimSpace(params["filename"]); v != "" {
return v
}
}
}
parsed, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return ""
}
base := strings.TrimSpace(path.Base(parsed.Path))
if base == "." || base == "/" {
return ""
}
return base
}
func decodeRFC5987Filename(value string) string {
parts := strings.SplitN(strings.TrimSpace(value), "''", 2)
encoded := strings.TrimSpace(value)
if len(parts) == 2 {
encoded = strings.TrimSpace(parts[1])
}
if encoded == "" {
return ""
}
decoded, err := url.QueryUnescape(encoded)
if err != nil {
return ""
}
return strings.TrimSpace(decoded)
}
@@ -0,0 +1,36 @@
package wecom
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
func TestDownloadFile_ParsesFilename(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Disposition", "attachment; filename*=UTF-8''hello%20wecom.txt")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer srv.Close()
client := NewHTTPClient(HTTPClientOptions{})
file, err := client.DownloadFile(context.Background(), srv.URL+"/download.bin")
if err != nil {
t.Fatalf("DownloadFile error = %v", err)
}
if file.FileName != "hello wecom.txt" {
t.Fatalf("unexpected filename: got=%q", file.FileName)
}
if string(file.Data) != "ok" {
t.Fatalf("unexpected payload: got=%q", string(file.Data))
}
}
func TestExtractFilename_FallbackPath(t *testing.T) {
got := extractFilename("", "https://example.com/files/a.png")
if got != "a.png" {
t.Fatalf("unexpected filename fallback: got=%q", got)
}
}
+338
View File
@@ -0,0 +1,338 @@
package wecom
import (
"context"
"strings"
"time"
"github.com/memohai/memoh/internal/channel"
)
func (a *WeComAdapter) handleFrame(ctx context.Context, cfg channel.ChannelConfig, frame WSFrame, handler channel.InboundHandler) error {
switch frame.Cmd {
case WSCmdMsgCallback:
if handler == nil {
return nil
}
var body MessageCallbackBody
if err := frame.DecodeBody(&body); err != nil {
return err
}
msg, ok := buildInboundMessage(body, frame.Headers.ReqID)
if !ok {
return nil
}
a.rememberCallback(body.MsgID, frame.Headers.ReqID, body.ResponseURL, body.ChatID, body.From.UserID)
return handler(ctx, cfg, msg)
case WSCmdEventCallback:
if handler == nil {
return nil
}
var body EventCallbackBody
if err := frame.DecodeBody(&body); err != nil {
return err
}
msg, ok := buildInboundEventMessage(body, frame.Headers.ReqID)
if !ok {
return nil
}
a.rememberCallback(body.MsgID, frame.Headers.ReqID, body.ResponseURL, body.ChatID, body.From.UserID)
return handler(ctx, cfg, msg)
default:
return nil
}
}
func buildInboundMessage(body MessageCallbackBody, reqID string) (channel.InboundMessage, bool) {
text, attachments := extractBodyContent(body)
if strings.TrimSpace(text) == "" && len(attachments) == 0 {
return channel.InboundMessage{}, false
}
target := resolveDeliveryTarget(body.ChatID, body.From.UserID)
if target == "" {
return channel.InboundMessage{}, false
}
convType := normalizeConversationType(body.ChatType)
convID := strings.TrimSpace(body.ChatID)
if convID == "" {
convID = strings.TrimSpace(body.From.UserID)
}
msg := channel.InboundMessage{
Channel: Type,
Message: channel.Message{
ID: strings.TrimSpace(body.MsgID),
Format: channel.MessageFormatPlain,
Text: strings.TrimSpace(text),
Attachments: attachments,
Metadata: map[string]any{
"response_url": strings.TrimSpace(body.ResponseURL),
},
},
ReplyTarget: target,
Sender: channel.Identity{
SubjectID: strings.TrimSpace(body.From.UserID),
Attributes: map[string]string{
"user_id": strings.TrimSpace(body.From.UserID),
"chat_id": strings.TrimSpace(body.ChatID),
},
},
Conversation: channel.Conversation{
ID: convID,
Type: convType,
},
ReceivedAt: parseCreateTime(body.CreateTime),
Source: "wecom",
Metadata: map[string]any{
"req_id": strings.TrimSpace(reqID),
"chat_id": strings.TrimSpace(body.ChatID),
"chat_type": strings.TrimSpace(body.ChatType),
"response_url": strings.TrimSpace(body.ResponseURL),
},
}
return msg, true
}
func buildInboundEventMessage(body EventCallbackBody, reqID string) (channel.InboundMessage, bool) {
eventType := normalizeEventType(body.Event.EventType, body.Event.EventType2)
if eventType == "" {
return channel.InboundMessage{}, false
}
target := resolveDeliveryTarget(body.ChatID, body.From.UserID)
if target == "" {
return channel.InboundMessage{}, false
}
convType := normalizeConversationType(body.ChatType)
convID := strings.TrimSpace(body.ChatID)
if convID == "" {
convID = strings.TrimSpace(body.From.UserID)
}
return channel.InboundMessage{
Channel: Type,
Message: channel.Message{
ID: strings.TrimSpace(body.MsgID),
Format: channel.MessageFormatPlain,
Text: eventType,
Metadata: map[string]any{
"event_type": eventType,
"task_id": strings.TrimSpace(body.Event.TaskID),
"event_key": strings.TrimSpace(body.Event.EventKey),
"event_code": strings.TrimSpace(body.Event.Code),
"event_reason": strings.TrimSpace(body.Event.Reason),
"task_status": strings.TrimSpace(body.Task.TaskStatus),
"response_url": strings.TrimSpace(body.ResponseURL),
},
},
ReplyTarget: target,
Sender: channel.Identity{
SubjectID: strings.TrimSpace(body.From.UserID),
Attributes: map[string]string{
"user_id": strings.TrimSpace(body.From.UserID),
"chat_id": strings.TrimSpace(body.ChatID),
},
},
Conversation: channel.Conversation{
ID: convID,
Type: convType,
},
ReceivedAt: parseCreateTime(body.CreateTime),
Source: "wecom",
Metadata: map[string]any{
"is_event": true,
"event_type": eventType,
"event_key": strings.TrimSpace(body.Event.EventKey),
"event_code": strings.TrimSpace(body.Event.Code),
"event_reason": strings.TrimSpace(body.Event.Reason),
"req_id": strings.TrimSpace(reqID),
"response_url": strings.TrimSpace(body.ResponseURL),
},
}, true
}
func extractBodyContent(body MessageCallbackBody) (string, []channel.Attachment) {
switch strings.ToLower(strings.TrimSpace(body.MsgType)) {
case "text":
if body.Text == nil {
return "", nil
}
return strings.TrimSpace(body.Text.Content), nil
case "markdown":
if body.Markdown == nil {
return "", nil
}
return strings.TrimSpace(body.Markdown.Content), nil
case "voice":
if body.Voice == nil {
return "", nil
}
return strings.TrimSpace(body.Voice.Content), nil
case "video":
if body.Video == nil {
return "", nil
}
att := channel.Attachment{
Type: channel.AttachmentVideo,
URL: strings.TrimSpace(body.Video.URL),
SourcePlatform: Type.String(),
Metadata: map[string]any{
"aeskey": strings.TrimSpace(body.Video.AESKey),
},
}
return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)}
case "image":
if body.Image == nil {
return "", nil
}
att := channel.Attachment{
Type: channel.AttachmentImage,
URL: strings.TrimSpace(body.Image.URL),
SourcePlatform: Type.String(),
Metadata: map[string]any{
"aeskey": strings.TrimSpace(body.Image.AESKey),
},
}
return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)}
case "file":
if body.File == nil {
return "", nil
}
att := channel.Attachment{
Type: channel.AttachmentFile,
URL: strings.TrimSpace(body.File.URL),
SourcePlatform: Type.String(),
Name: strings.TrimSpace(body.File.FileName),
Metadata: map[string]any{
"aeskey": strings.TrimSpace(body.File.AESKey),
},
}
return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)}
case "mixed":
var textParts []string
attachments := make([]channel.Attachment, 0)
for _, item := range body.Mixed {
switch strings.ToLower(strings.TrimSpace(item.MsgType)) {
case "text":
if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" {
textParts = append(textParts, strings.TrimSpace(item.Text.Content))
}
case "markdown":
if item.Markdown != nil && strings.TrimSpace(item.Markdown.Content) != "" {
textParts = append(textParts, strings.TrimSpace(item.Markdown.Content))
}
case "image":
if item.Image == nil {
continue
}
att := channel.Attachment{
Type: channel.AttachmentImage,
URL: strings.TrimSpace(item.Image.URL),
SourcePlatform: Type.String(),
Metadata: map[string]any{
"aeskey": strings.TrimSpace(item.Image.AESKey),
},
}
attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att))
case "file":
if item.File == nil {
continue
}
att := channel.Attachment{
Type: channel.AttachmentFile,
URL: strings.TrimSpace(item.File.URL),
SourcePlatform: Type.String(),
Name: strings.TrimSpace(item.File.FileName),
Metadata: map[string]any{
"aeskey": strings.TrimSpace(item.File.AESKey),
},
}
attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att))
case "video":
if item.Video == nil {
continue
}
att := channel.Attachment{
Type: channel.AttachmentVideo,
URL: strings.TrimSpace(item.Video.URL),
SourcePlatform: Type.String(),
Metadata: map[string]any{
"aeskey": strings.TrimSpace(item.Video.AESKey),
},
}
attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att))
case "voice":
if item.Voice != nil && strings.TrimSpace(item.Voice.Content) != "" {
textParts = append(textParts, strings.TrimSpace(item.Voice.Content))
}
}
}
return strings.Join(textParts, "\n"), attachments
default:
return "", nil
}
}
func resolveDeliveryTarget(chatID, userID string) string {
if v := strings.TrimSpace(chatID); v != "" {
return "chat_id:" + v
}
if v := strings.TrimSpace(userID); v != "" {
return "user_id:" + v
}
return ""
}
func normalizeConversationType(chatType string) string {
if strings.EqualFold(strings.TrimSpace(chatType), "group") {
return "group"
}
return "private"
}
func parseCreateTime(ts int64) time.Time {
if t := unixMilliseconds(ts); !t.IsZero() {
return t
}
return time.Now().UTC()
}
func (a *WeComAdapter) rememberCallback(msgID, reqID, responseURL, chatID, userID string) {
if a == nil || a.cache == nil {
return
}
msgID = strings.TrimSpace(msgID)
if msgID == "" {
return
}
if strings.TrimSpace(reqID) == "" {
return
}
a.cache.Put(msgID, callbackContext{
ReqID: strings.TrimSpace(reqID),
ResponseURL: strings.TrimSpace(responseURL),
ChatID: strings.TrimSpace(chatID),
UserID: strings.TrimSpace(userID),
CreatedAt: time.Now().UTC(),
})
}
func normalizeEventType(values ...string) string {
candidates := make([]string, 0, len(values))
for _, v := range values {
v = strings.TrimSpace(strings.ToLower(v))
if v == "" {
continue
}
v = strings.ReplaceAll(v, "-", "_")
v = strings.ReplaceAll(v, " ", "_")
candidates = append(candidates, v)
}
for _, v := range candidates {
switch v {
case "enter_chat", "template_card_event", "feedback_event":
return v
}
}
if len(candidates) == 0 {
return ""
}
return candidates[0]
}
@@ -0,0 +1,103 @@
package wecom
import "testing"
func TestBuildInboundMessage_Text(t *testing.T) {
msg, ok := buildInboundMessage(MessageCallbackBody{
MsgID: "m1",
ChatID: "chat-1",
ChatType: "group",
From: CallbackFrom{
UserID: "u1",
},
MsgType: "text",
Text: &MessageText{
Content: "hello",
},
}, "req-1")
if !ok {
t.Fatal("expected inbound message")
}
if msg.Message.Text != "hello" {
t.Fatalf("unexpected text: %q", msg.Message.Text)
}
if msg.ReplyTarget != "chat_id:chat-1" {
t.Fatalf("unexpected target: %q", msg.ReplyTarget)
}
if msg.Metadata["req_id"] != "req-1" {
t.Fatalf("unexpected req_id: %v", msg.Metadata["req_id"])
}
}
func TestBuildInboundMessage_Markdown(t *testing.T) {
msg, ok := buildInboundMessage(MessageCallbackBody{
MsgID: "m2",
ChatID: "chat-2",
ChatType: "private",
From: CallbackFrom{UserID: "u2"},
MsgType: "markdown",
Markdown: &MessageText{
Content: "**hello**",
},
}, "req-2")
if !ok {
t.Fatal("expected inbound markdown message")
}
if msg.Message.Text != "**hello**" {
t.Fatalf("unexpected markdown text: %q", msg.Message.Text)
}
}
func TestBuildInboundMessage_Video(t *testing.T) {
msg, ok := buildInboundMessage(MessageCallbackBody{
MsgID: "m3",
ChatID: "chat-3",
ChatType: "group",
From: CallbackFrom{UserID: "u3"},
MsgType: "video",
Video: &MessageVideo{
URL: "https://example.com/v.mp4",
AESKey: "k",
},
}, "req-3")
if !ok {
t.Fatal("expected inbound video message")
}
if len(msg.Message.Attachments) != 1 {
t.Fatalf("expected one attachment, got %d", len(msg.Message.Attachments))
}
if msg.Message.Attachments[0].Type != "video" {
t.Fatalf("unexpected attachment type: %s", msg.Message.Attachments[0].Type)
}
}
func TestBuildInboundEventMessage_NormalizeEventType(t *testing.T) {
msg, ok := buildInboundEventMessage(EventCallbackBody{
MsgID: "e1",
ChatID: "chat-1",
ChatType: "group",
From: CallbackFrom{UserID: "u1"},
CreateTime: 1,
MsgType: "event",
Event: EventPayload{
EventType2: "template-card-event",
EventKey: "btn_1",
TaskID: "task_1",
},
Task: EventTask{
TaskStatus: "done",
},
}, "req-e1")
if !ok {
t.Fatal("expected event message")
}
if msg.Message.Text != "template_card_event" {
t.Fatalf("unexpected normalized event text: %q", msg.Message.Text)
}
if msg.Message.Metadata["event_key"] != "btn_1" {
t.Fatalf("unexpected event key: %v", msg.Message.Metadata["event_key"])
}
if msg.Message.Metadata["task_status"] != "done" {
t.Fatalf("unexpected task status: %v", msg.Message.Metadata["task_status"])
}
}
+309
View File
@@ -0,0 +1,309 @@
package wecom
import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"fmt"
"io"
"strings"
"github.com/memohai/memoh/internal/channel"
)
func (a *WeComAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) {
_ = cfg
if a.http == nil {
return channel.AttachmentPayload{}, fmt.Errorf("wecom http client not configured")
}
url := strings.TrimSpace(attachment.URL)
if url == "" {
return channel.AttachmentPayload{}, fmt.Errorf("wecom attachment url is required")
}
aesKey := ""
if attachment.Metadata != nil {
if value, ok := attachment.Metadata["aeskey"].(string); ok {
aesKey = strings.TrimSpace(value)
}
}
var file DownloadedFile
var err error
if aesKey != "" {
file, err = a.http.DownloadAndDecryptFile(ctx, url, aesKey)
} else {
file, err = a.http.DownloadFile(ctx, url)
}
if err != nil {
return channel.AttachmentPayload{}, err
}
return channel.AttachmentPayload{
Reader: io.NopCloser(bytes.NewReader(file.Data)),
Mime: strings.TrimSpace(file.ContentType),
Name: strings.TrimSpace(file.FileName),
Size: int64(len(file.Data)),
}, nil
}
type markdownPayload struct {
Content string `json:"content"`
}
func buildSendPayload(msg channel.Message, targetID string) (any, string, string, error) {
if strings.TrimSpace(targetID) == "" {
return nil, "", "", fmt.Errorf("wecom target id is required")
}
reqID := NewReqID(WSCmdSendMessage)
if card, ok := readTemplateCard(msg.Metadata); ok {
return SendMessageTemplateCardBody{
ChatID: targetID,
MsgType: "template_card",
TemplateCard: card,
}, WSCmdSendMessage, reqID, nil
}
// aibot_send_msg currently supports markdown/template_card in official SDK.
// Attachments should be sent through callback-reply path (aibot_respond_msg).
if len(msg.Attachments) > 0 {
return nil, "", "", fmt.Errorf("wecom proactive send does not support attachments; use reply flow")
}
text := strings.TrimSpace(msg.PlainText())
if text == "" {
return nil, "", "", fmt.Errorf("wecom outbound text is required")
}
return SendMessageMarkdownBody{
ChatID: targetID,
MsgType: "markdown",
Markdown: markdownPayload{
Content: text,
},
}, WSCmdSendMessage, reqID, nil
}
func buildRespondPayload(msg channel.Message, replyReqID string) (any, string, string, error) {
return buildRespondPayloadWithStream(msg, replyReqID, "", true)
}
func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, streamID string, finish bool) (any, string, string, error) {
reqID := strings.TrimSpace(replyReqID)
if reqID == "" {
return nil, "", "", fmt.Errorf("reply req_id is required")
}
if finish {
if body, ok := buildWelcomePayload(msg); ok {
return body, WSCmdRespondWelcome, reqID, nil
}
if body, ok := readUpdateTemplateCard(msg.Metadata); ok {
return body, WSCmdRespondUpdate, reqID, nil
}
}
text := strings.TrimSpace(msg.PlainText())
if finish && text == "" && len(msg.Attachments) == 0 {
return nil, "", "", fmt.Errorf("wecom reply payload is empty")
}
if !finish && text == "" {
return nil, "", "", fmt.Errorf("wecom stream delta content is empty")
}
streamID = strings.TrimSpace(streamID)
if streamID == "" {
streamID = NewReqID("stream")
}
stream := StreamReplyBlock{
ID: streamID,
Finish: finish,
Content: text,
}
if feedbackID := readFeedbackID(msg.Metadata); feedbackID != "" {
stream.Feedback = &StreamReplyFeedback{ID: feedbackID}
}
if finish && len(msg.Attachments) > 0 {
first := msg.Attachments[0]
base64Data := extractBase64Content(first.Base64)
if base64Data != "" {
raw, err := base64.StdEncoding.DecodeString(base64Data)
if err == nil && len(raw) > 0 {
stream.MsgItems = []StreamReplyItem{
{
MsgType: "image",
Image: &StreamReplyImage{
Base64: base64.StdEncoding.EncodeToString(raw),
MD5: fmt.Sprintf("%x", md5.Sum(raw)),
},
},
}
}
}
}
if card, ok := readTemplateCard(msg.Metadata); ok {
return StreamWithTemplateCardReplyBody{
MsgType: "stream_with_template_card",
Stream: stream,
TemplateCard: card,
}, WSCmdRespond, reqID, nil
}
return StreamReplyBody{
MsgType: "stream",
Stream: stream,
}, WSCmdRespond, reqID, nil
}
func buildWelcomePayload(msg channel.Message) (any, bool) {
if !readBool(msg.Metadata, "wecom_welcome") {
return nil, false
}
if card, ok := readTemplateCard(msg.Metadata); ok {
return WelcomeTemplateCardReplyBody{
MsgType: "template_card",
TemplateCard: card,
}, true
}
text := strings.TrimSpace(msg.PlainText())
if text == "" {
return nil, false
}
return WelcomeTextReplyBody{
MsgType: "text",
Text: welcomeTextBody{
Content: text,
},
}, true
}
func readTemplateCard(metadata map[string]any) (map[string]any, bool) {
if metadata == nil {
return nil, false
}
raw, ok := metadata["wecom_template_card"]
if !ok {
return nil, false
}
card, ok := raw.(map[string]any)
if !ok || len(card) == 0 {
return nil, false
}
return card, true
}
func readUpdateTemplateCard(metadata map[string]any) (UpdateTemplateCardBody, bool) {
if metadata == nil {
return UpdateTemplateCardBody{}, false
}
raw, ok := metadata["wecom_update_template_card"]
if !ok {
return UpdateTemplateCardBody{}, false
}
card, ok := raw.(map[string]any)
if !ok || len(card) == 0 {
return UpdateTemplateCardBody{}, false
}
body := UpdateTemplateCardBody{
ResponseType: "update_template_card",
TemplateCard: card,
}
if userIDs := readStringSlice(metadata["wecom_update_userids"]); len(userIDs) > 0 {
body.UserIDs = userIDs
}
return body, true
}
func readStringSlice(raw any) []string {
switch v := raw.(type) {
case []string:
out := make([]string, 0, len(v))
for _, item := range v {
if s := strings.TrimSpace(item); s != "" {
out = append(out, s)
}
}
return out
case []any:
out := make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok && strings.TrimSpace(s) != "" {
out = append(out, strings.TrimSpace(s))
}
}
return out
case string:
if strings.TrimSpace(v) == "" {
return nil
}
parts := strings.Split(v, ",")
out := make([]string, 0, len(parts))
for _, part := range parts {
if s := strings.TrimSpace(part); s != "" {
out = append(out, s)
}
}
return out
default:
return nil
}
}
func readFeedbackID(metadata map[string]any) string {
if metadata == nil {
return ""
}
raw, ok := metadata["wecom_feedback_id"]
if !ok {
return ""
}
if v, ok := raw.(string); ok {
return strings.TrimSpace(v)
}
return ""
}
func readBool(metadata map[string]any, key string) bool {
if metadata == nil || strings.TrimSpace(key) == "" {
return false
}
raw, ok := metadata[key]
if !ok {
return false
}
v, ok := raw.(bool)
return ok && v
}
func extractBase64Content(v string) string {
value := strings.TrimSpace(v)
if value == "" {
return ""
}
if idx := strings.Index(value, ","); idx > 0 && strings.Contains(strings.ToLower(value[:idx]), "base64") {
return strings.TrimSpace(value[idx+1:])
}
return value
}
func (a *WeComAdapter) getClient(botID string) *WSClient {
key := strings.TrimSpace(botID)
if key == "" {
return nil
}
a.mu.RLock()
defer a.mu.RUnlock()
return a.clients[key]
}
func (a *WeComAdapter) lookupCallbackContext(reply *channel.ReplyRef) (callbackContext, bool) {
if a == nil || a.cache == nil || reply == nil {
return callbackContext{}, false
}
messageID := strings.TrimSpace(reply.MessageID)
if messageID == "" {
return callbackContext{}, false
}
return a.cache.Get(messageID)
}
func (a *WeComAdapter) ensureHTTPClient() {
a.mu.Lock()
defer a.mu.Unlock()
if a.http == nil {
a.http = NewHTTPClient(HTTPClientOptions{Logger: a.logger})
}
}
@@ -0,0 +1,239 @@
package wecom
import (
"testing"
"github.com/memohai/memoh/internal/channel"
)
func TestBuildSendPayload_Text(t *testing.T) {
payload, cmd, reqID, err := buildSendPayload(channel.Message{
Format: channel.MessageFormatPlain,
Text: "hello",
}, "chat_1")
if err != nil {
t.Fatalf("buildSendPayload error = %v", err)
}
if cmd != WSCmdSendMessage {
t.Fatalf("unexpected cmd: %q", cmd)
}
if reqID == "" {
t.Fatal("reqID should not be empty")
}
p, ok := payload.(SendMessageMarkdownBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.Markdown.Content != "hello" {
t.Fatalf("unexpected payload content: %q", p.Markdown.Content)
}
}
func TestBuildSendPayload_AttachmentNotSupported(t *testing.T) {
_, _, _, err := buildSendPayload(channel.Message{
Attachments: []channel.Attachment{
{Type: channel.AttachmentImage, Base64: "aGVsbG8="},
},
}, "chat_1")
if err == nil {
t.Fatal("expected error for proactive attachment send")
}
}
func TestBuildRespondPayload_Stream(t *testing.T) {
payload, cmd, reqID, err := buildRespondPayload(channel.Message{
Format: channel.MessageFormatMarkdown,
Text: "world",
}, "req_abc")
if err != nil {
t.Fatalf("buildRespondPayload error = %v", err)
}
if cmd != WSCmdRespond {
t.Fatalf("unexpected cmd: %q", cmd)
}
if reqID != "req_abc" {
t.Fatalf("unexpected req id: %q", reqID)
}
p, ok := payload.(StreamReplyBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.MsgType != "stream" || p.Stream.Content != "world" || !p.Stream.Finish {
t.Fatalf("unexpected stream payload: %+v", p)
}
}
func TestBuildRespondPayload_StreamWithFeedback(t *testing.T) {
payload, cmd, _, err := buildRespondPayload(channel.Message{
Text: "world",
Metadata: map[string]any{
"wecom_feedback_id": "feedback_1",
},
}, "req_abc")
if err != nil {
t.Fatalf("buildRespondPayload error = %v", err)
}
if cmd != WSCmdRespond {
t.Fatalf("unexpected cmd: %q", cmd)
}
p, ok := payload.(StreamReplyBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.Stream.Feedback == nil || p.Stream.Feedback.ID != "feedback_1" {
t.Fatalf("unexpected feedback: %+v", p.Stream.Feedback)
}
}
func TestBuildRespondPayloadWithStream_Delta(t *testing.T) {
payload, cmd, reqID, err := buildRespondPayloadWithStream(channel.Message{
Text: "delta",
}, "req_abc", "stream_1", false)
if err != nil {
t.Fatalf("buildRespondPayloadWithStream error = %v", err)
}
if cmd != WSCmdRespond || reqID != "req_abc" {
t.Fatalf("unexpected cmd/reqid: %q %q", cmd, reqID)
}
p, ok := payload.(StreamReplyBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.Stream.ID != "stream_1" || p.Stream.Finish {
t.Fatalf("unexpected stream payload: %+v", p.Stream)
}
}
func TestBuildRespondPayloadWithStream_EmptyDeltaRejected(t *testing.T) {
_, _, _, err := buildRespondPayloadWithStream(channel.Message{}, "req_abc", "stream_1", false)
if err == nil {
t.Fatal("expected error for empty delta content")
}
}
func TestBuildSendPayload_TemplateCard(t *testing.T) {
payload, cmd, _, err := buildSendPayload(channel.Message{
Metadata: map[string]any{
"wecom_template_card": map[string]any{
"card_type": "text_notice",
},
},
}, "chat_1")
if err != nil {
t.Fatalf("buildSendPayload error = %v", err)
}
if cmd != WSCmdSendMessage {
t.Fatalf("unexpected cmd: %q", cmd)
}
p, ok := payload.(SendMessageTemplateCardBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.MsgType != "template_card" {
t.Fatalf("unexpected msg type: %q", p.MsgType)
}
}
func TestBuildRespondPayload_StreamWithTemplateCard(t *testing.T) {
payload, cmd, _, err := buildRespondPayload(channel.Message{
Text: "x",
Metadata: map[string]any{
"wecom_template_card": map[string]any{
"card_type": "text_notice",
},
},
}, "req_abc")
if err != nil {
t.Fatalf("buildRespondPayload error = %v", err)
}
if cmd != WSCmdRespond {
t.Fatalf("unexpected cmd: %q", cmd)
}
p, ok := payload.(StreamWithTemplateCardReplyBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.MsgType != "stream_with_template_card" {
t.Fatalf("unexpected msg type: %q", p.MsgType)
}
}
func TestBuildRespondPayload_UpdateTemplateCard(t *testing.T) {
payload, cmd, reqID, err := buildRespondPayload(channel.Message{
Metadata: map[string]any{
"wecom_update_template_card": map[string]any{
"card_type": "text_notice",
"task_id": "task-1",
},
"wecom_update_userids": []any{"u1", "u2"},
},
}, "req_abc")
if err != nil {
t.Fatalf("buildRespondPayload error = %v", err)
}
if cmd != WSCmdRespondUpdate {
t.Fatalf("unexpected cmd: %q", cmd)
}
if reqID != "req_abc" {
t.Fatalf("unexpected req id: %q", reqID)
}
p, ok := payload.(UpdateTemplateCardBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.ResponseType != "update_template_card" {
t.Fatalf("unexpected response type: %q", p.ResponseType)
}
if len(p.UserIDs) != 2 || p.UserIDs[0] != "u1" || p.UserIDs[1] != "u2" {
t.Fatalf("unexpected userids: %#v", p.UserIDs)
}
}
func TestBuildRespondPayload_WelcomeText(t *testing.T) {
payload, cmd, reqID, err := buildRespondPayload(channel.Message{
Text: "welcome",
Metadata: map[string]any{
"wecom_welcome": true,
},
}, "req_abc")
if err != nil {
t.Fatalf("buildRespondPayload error = %v", err)
}
if cmd != WSCmdRespondWelcome {
t.Fatalf("unexpected cmd: %q", cmd)
}
if reqID != "req_abc" {
t.Fatalf("unexpected req id: %q", reqID)
}
p, ok := payload.(WelcomeTextReplyBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.MsgType != "text" || p.Text.Content != "welcome" {
t.Fatalf("unexpected welcome payload: %+v", p)
}
}
func TestBuildRespondPayload_WelcomeTemplateCard(t *testing.T) {
payload, cmd, _, err := buildRespondPayload(channel.Message{
Metadata: map[string]any{
"wecom_welcome": true,
"wecom_template_card": map[string]any{
"card_type": "text_notice",
},
},
}, "req_abc")
if err != nil {
t.Fatalf("buildRespondPayload error = %v", err)
}
if cmd != WSCmdRespondWelcome {
t.Fatalf("unexpected cmd: %q", cmd)
}
p, ok := payload.(WelcomeTemplateCardReplyBody)
if !ok {
t.Fatalf("unexpected payload type: %T", payload)
}
if p.MsgType != "template_card" {
t.Fatalf("unexpected msg type: %q", p.MsgType)
}
}
+256
View File
@@ -0,0 +1,256 @@
package wecom
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
)
const (
defaultWSURL = "wss://openws.work.weixin.qq.com"
)
const (
WSCmdSubscribe = "aibot_subscribe"
WSCmdHeartbeat = "ping"
WSCmdRespond = "aibot_respond_msg"
WSCmdRespondWelcome = "aibot_respond_welcome_msg"
WSCmdRespondUpdate = "aibot_respond_update_msg"
WSCmdSendMessage = "aibot_send_msg"
WSCmdMsgCallback = "aibot_msg_callback"
WSCmdEventCallback = "aibot_event_callback"
)
type WSHeaders struct {
ReqID string `json:"req_id"`
}
type WSFrame struct {
Cmd string `json:"cmd,omitempty"`
Headers WSHeaders `json:"headers"`
Body json.RawMessage `json:"body,omitempty"`
ErrCode int `json:"errcode,omitempty"`
ErrMsg string `json:"errmsg,omitempty"`
}
func (f WSFrame) DecodeBody(dst any) error {
if len(f.Body) == 0 {
return fmt.Errorf("wecom frame body is empty")
}
if dst == nil {
return fmt.Errorf("decode target is nil")
}
return json.Unmarshal(f.Body, dst)
}
func BuildFrame(cmd, reqID string, body any) (WSFrame, error) {
frame := WSFrame{
Cmd: cmd,
Headers: WSHeaders{
ReqID: strings.TrimSpace(reqID),
},
}
if frame.Headers.ReqID == "" {
return WSFrame{}, fmt.Errorf("req_id is required")
}
if body == nil {
return frame, nil
}
raw, err := json.Marshal(body)
if err != nil {
return WSFrame{}, err
}
frame.Body = raw
return frame, nil
}
func NewReqID(prefix string) string {
p := strings.TrimSpace(prefix)
if p == "" {
return uuid.NewString()
}
return p + "_" + uuid.NewString()
}
type AuthCredentials struct {
BotID string
Secret string
}
func (c AuthCredentials) Validate() error {
if strings.TrimSpace(c.BotID) == "" {
return fmt.Errorf("wecom bot_id is required")
}
if strings.TrimSpace(c.Secret) == "" {
return fmt.Errorf("wecom secret is required")
}
return nil
}
type SubscribeBody struct {
BotID string `json:"bot_id"`
Secret string `json:"secret"`
}
type CallbackFrom struct {
UserID string `json:"userid,omitempty"`
}
type MessageText struct {
Content string `json:"content,omitempty"`
}
type MessageImage struct {
URL string `json:"url,omitempty"`
AESKey string `json:"aeskey,omitempty"`
}
type MessageFile struct {
URL string `json:"url,omitempty"`
AESKey string `json:"aeskey,omitempty"`
FileName string `json:"file_name,omitempty"`
}
type MessageVoice struct {
Content string `json:"content,omitempty"`
}
type MessageVideo struct {
URL string `json:"url,omitempty"`
AESKey string `json:"aeskey,omitempty"`
}
type MessageMixedItem struct {
MsgType string `json:"msgtype,omitempty"`
Text *MessageText `json:"text,omitempty"`
Markdown *MessageText `json:"markdown,omitempty"`
Image *MessageImage `json:"image,omitempty"`
File *MessageFile `json:"file,omitempty"`
Voice *MessageVoice `json:"voice,omitempty"`
Video *MessageVideo `json:"video,omitempty"`
}
type MessageQuote struct {
MsgID string `json:"msgid,omitempty"`
}
type MessageCallbackBody struct {
MsgID string `json:"msgid,omitempty"`
AIBotID string `json:"aibotid,omitempty"`
ChatID string `json:"chatid,omitempty"`
ChatType string `json:"chattype,omitempty"`
From CallbackFrom `json:"from,omitempty"`
CreateTime int64 `json:"create_time,omitempty"`
MsgType string `json:"msgtype,omitempty"`
ResponseURL string `json:"response_url,omitempty"`
Text *MessageText `json:"text,omitempty"`
Markdown *MessageText `json:"markdown,omitempty"`
Image *MessageImage `json:"image,omitempty"`
File *MessageFile `json:"file,omitempty"`
Voice *MessageVoice `json:"voice,omitempty"`
Video *MessageVideo `json:"video,omitempty"`
Mixed []MessageMixedItem `json:"mixed,omitempty"`
Quote *MessageQuote `json:"quote,omitempty"`
}
type EventPayload struct {
EventType string `json:"event_type,omitempty"`
EventType2 string `json:"eventtype,omitempty"`
EventKey string `json:"event_key,omitempty"`
TaskID string `json:"task_id,omitempty"`
Code string `json:"code,omitempty"`
Reason string `json:"reason,omitempty"`
}
type EventTask struct {
TaskID string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
}
type EventCallbackBody struct {
MsgID string `json:"msgid,omitempty"`
AIBotID string `json:"aibotid,omitempty"`
ChatID string `json:"chatid,omitempty"`
ChatType string `json:"chattype,omitempty"`
From CallbackFrom `json:"from,omitempty"`
CreateTime int64 `json:"create_time,omitempty"`
MsgType string `json:"msgtype,omitempty"`
ResponseURL string `json:"response_url,omitempty"`
Event EventPayload `json:"event,omitempty"`
Task EventTask `json:"task,omitempty"`
}
type StreamReplyBody struct {
MsgType string `json:"msgtype"`
Stream StreamReplyBlock `json:"stream"`
}
type StreamReplyBlock struct {
ID string `json:"id"`
Finish bool `json:"finish,omitempty"`
Content string `json:"content,omitempty"`
MsgItems []StreamReplyItem `json:"msg_item,omitempty"`
Feedback *StreamReplyFeedback `json:"feedback,omitempty"`
}
type StreamReplyItem struct {
MsgType string `json:"msgtype"`
Image *StreamReplyImage `json:"image,omitempty"`
}
type StreamReplyImage struct {
Base64 string `json:"base64"`
MD5 string `json:"md5"`
}
type StreamReplyFeedback struct {
ID string `json:"id"`
}
type SendMessageMarkdownBody struct {
ChatID string `json:"chatid"`
MsgType string `json:"msgtype"`
Markdown markdownPayload `json:"markdown"`
}
type SendMessageTemplateCardBody struct {
ChatID string `json:"chatid"`
MsgType string `json:"msgtype"`
TemplateCard map[string]any `json:"template_card"`
}
type StreamWithTemplateCardReplyBody struct {
MsgType string `json:"msgtype"`
Stream StreamReplyBlock `json:"stream"`
TemplateCard map[string]any `json:"template_card"`
}
type WelcomeTextReplyBody struct {
MsgType string `json:"msgtype"`
Text welcomeTextBody `json:"text"`
}
type welcomeTextBody struct {
Content string `json:"content"`
}
type WelcomeTemplateCardReplyBody struct {
MsgType string `json:"msgtype"`
TemplateCard map[string]any `json:"template_card"`
}
type UpdateTemplateCardBody struct {
ResponseType string `json:"response_type"`
UserIDs []string `json:"userids,omitempty"`
TemplateCard map[string]any `json:"template_card"`
}
func unixMilliseconds(ts int64) time.Time {
if ts <= 0 {
return time.Time{}
}
return time.UnixMilli(ts)
}
@@ -0,0 +1,29 @@
package wecom
import "testing"
func TestBuildFrameAndDecodeBody(t *testing.T) {
type body struct {
BotID string `json:"bot_id"`
}
frame, err := BuildFrame(WSCmdSubscribe, "req-1", body{BotID: "bot123"})
if err != nil {
t.Fatalf("BuildFrame error = %v", err)
}
var decoded body
if err := frame.DecodeBody(&decoded); err != nil {
t.Fatalf("DecodeBody error = %v", err)
}
if decoded.BotID != "bot123" {
t.Fatalf("unexpected decoded body: %+v", decoded)
}
}
func TestAuthCredentialsValidate(t *testing.T) {
if err := (AuthCredentials{}).Validate(); err == nil {
t.Fatal("expected validation error for empty credentials")
}
if err := (AuthCredentials{BotID: "id", Secret: "sec"}).Validate(); err != nil {
t.Fatalf("unexpected validation error: %v", err)
}
}
+434
View File
@@ -0,0 +1,434 @@
package wecom
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/memohai/memoh/internal/channel"
)
const Type channel.ChannelType = "wecom"
type wsClientFactory func(opts WSClientOptions) *WSClient
type WeComAdapter struct {
logger *slog.Logger
mu sync.RWMutex
clients map[string]*WSClient
http *HTTPClient
cache *callbackContextCache
newWSClient wsClientFactory
}
func NewWeComAdapter(log *slog.Logger) *WeComAdapter {
if log == nil {
log = slog.Default()
}
return &WeComAdapter{
logger: log.With(slog.String("adapter", "wecom")),
clients: make(map[string]*WSClient),
http: NewHTTPClient(HTTPClientOptions{Logger: log}),
cache: newCallbackContextCache(24 * time.Hour),
newWSClient: func(opts WSClientOptions) *WSClient { return NewWSClient(opts) },
}
}
func (a *WeComAdapter) Type() channel.ChannelType { return Type }
func (a *WeComAdapter) Descriptor() channel.Descriptor {
return channel.Descriptor{
Type: Type,
DisplayName: "WeCom",
Capabilities: channel.ChannelCapabilities{
Text: true,
Markdown: true,
Attachments: true,
Media: true,
Reply: true,
Streaming: true,
BlockStreaming: true,
ChatTypes: []string{"private", "group"},
},
ConfigSchema: channel.ConfigSchema{
Version: 1,
Fields: map[string]channel.FieldSchema{
"botId": {Type: channel.FieldString, Required: true, Title: "Bot ID"},
"secret": {Type: channel.FieldSecret, Required: true, Title: "Secret"},
"wsUrl": {Type: channel.FieldString, Title: "WebSocket URL", Example: defaultWSURL},
"heartbeatSeconds": {Type: channel.FieldNumber, Title: "Heartbeat Seconds"},
"ackTimeoutSeconds": {Type: channel.FieldNumber, Title: "Ack Timeout Seconds"},
"writeTimeoutSeconds": {Type: channel.FieldNumber, Title: "Write Timeout Seconds"},
"readTimeoutSeconds": {Type: channel.FieldNumber, Title: "Read Timeout Seconds"},
},
},
UserConfigSchema: channel.ConfigSchema{
Version: 1,
Fields: map[string]channel.FieldSchema{
"chat_id": {Type: channel.FieldString},
"user_id": {Type: channel.FieldString},
},
},
TargetSpec: channel.TargetSpec{
Format: "chat_id:xxx | user_id:xxx",
Hints: []channel.TargetHint{
{Label: "Chat ID", Example: "chat_id:wrk_abc"},
{Label: "User ID", Example: "user_id:zhangsan"},
},
},
}
}
func (a *WeComAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) {
return normalizeConfig(raw)
}
func (a *WeComAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) {
return normalizeUserConfig(raw)
}
func (a *WeComAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) }
func (a *WeComAdapter) ResolveTarget(userConfig map[string]any) (string, error) {
return resolveTarget(userConfig)
}
func (a *WeComAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool {
return matchBinding(config, criteria)
}
func (a *WeComAdapter) BuildUserConfig(identity channel.Identity) map[string]any {
return buildUserConfig(identity)
}
func (a *WeComAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) {
_ = ctx
cfg, err := parseConfig(credentials)
if err != nil {
return nil, "", err
}
externalID := strings.TrimSpace(cfg.BotID)
identity := map[string]any{
"bot_id": externalID,
"aibot_id": externalID,
}
return identity, externalID, nil
}
func (a *WeComAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) {
parsed, err := parseConfig(cfg.Credentials)
if err != nil {
return nil, err
}
client := a.newWSClient(WSClientOptions{
URL: parsed.WSURL,
Logger: a.logger,
HeartbeatInterval: time.Duration(secondsOrDefault(parsed.HeartbeatSeconds, 30)) * time.Second,
AckTimeout: time.Duration(secondsOrDefault(parsed.AckTimeoutSeconds, 8)) * time.Second,
WriteTimeout: time.Duration(secondsOrDefault(parsed.WriteTimeoutSeconds, 8)) * time.Second,
ReadTimeout: time.Duration(secondsOrDefault(parsed.ReadTimeoutSeconds, 70)) * time.Second,
ReconnectBaseDelay: 1 * time.Second,
ReconnectMaxDelay: 30 * time.Second,
})
key := strings.TrimSpace(parsed.BotID)
a.mu.Lock()
a.clients[key] = client
a.mu.Unlock()
connCtx, cancel := context.WithCancel(ctx)
done := make(chan struct{})
go func() {
defer close(done)
err := client.Run(connCtx, AuthCredentials{
BotID: parsed.BotID,
Secret: parsed.Secret,
}, func(frameCtx context.Context, frame WSFrame) error {
return a.handleFrame(frameCtx, cfg, frame, handler)
})
if err != nil && connCtx.Err() == nil {
a.logger.Error("wecom websocket stopped",
slog.String("config_id", cfg.ID),
slog.Any("error", err),
)
}
}()
stop := func(context.Context) error {
cancel()
_ = client.Close()
<-done
a.mu.Lock()
if current, ok := a.clients[key]; ok && current == client {
delete(a.clients, key)
}
a.mu.Unlock()
return nil
}
return channel.NewConnection(cfg, stop), nil
}
func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error {
targetKind, targetID, ok := parseTarget(msg.Target)
if !ok {
return fmt.Errorf("wecom target is required")
}
parsed, err := parseConfig(cfg.Credentials)
if err != nil {
return err
}
client := a.getClient(parsed.BotID)
if client == nil {
return fmt.Errorf("wecom connection is not active")
}
if msg.Message.IsEmpty() {
return fmt.Errorf("message is required")
}
var (
payload any
cmd string
reqID string
buildErr error
)
if ctxMeta, ok := a.lookupCallbackContext(msg.Message.Reply); ok {
payload, cmd, reqID, buildErr = buildRespondPayload(msg.Message, ctxMeta.ReqID)
} else {
_ = targetKind
payload, cmd, reqID, buildErr = buildSendPayload(msg.Message, targetID)
}
if buildErr != nil {
return buildErr
}
ack, err := client.Reply(ctx, reqID, cmd, payload)
if err != nil {
return err
}
if ack.ErrCode != 0 {
return fmt.Errorf("wecom send failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode)
}
return nil
}
func (a *WeComAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) {
target = strings.TrimSpace(target)
if target == "" {
return nil, fmt.Errorf("wecom target is required")
}
reply := opts.Reply
if reply == nil && strings.TrimSpace(opts.SourceMessageID) != "" {
reply = &channel.ReplyRef{
Target: target,
MessageID: strings.TrimSpace(opts.SourceMessageID),
}
}
return &wecomOutboundStream{
adapter: a,
cfg: cfg,
target: target,
reply: reply,
}, nil
}
type wecomOutboundStream struct {
adapter *WeComAdapter
cfg channel.ChannelConfig
target string
reply *channel.ReplyRef
mu sync.Mutex
closed atomic.Bool
finalSent atomic.Bool
textBuilder strings.Builder
attachments []channel.Attachment
final *channel.Message
streamID string
lastPreview string
}
func (s *wecomOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error {
if s.adapter == nil {
return errors.New("wecom stream not configured")
}
if s.closed.Load() {
return errors.New("wecom stream is closed")
}
if s.finalSent.Load() {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
switch event.Type {
case channel.StreamEventStatus,
channel.StreamEventPhaseStart,
channel.StreamEventPhaseEnd,
channel.StreamEventToolCallStart,
channel.StreamEventToolCallEnd,
channel.StreamEventAgentStart,
channel.StreamEventAgentEnd,
channel.StreamEventProcessingStarted,
channel.StreamEventProcessingCompleted,
channel.StreamEventProcessingFailed:
return nil
case channel.StreamEventDelta:
if strings.TrimSpace(event.Delta) == "" || event.Phase == channel.StreamPhaseReasoning {
return nil
}
s.mu.Lock()
s.textBuilder.WriteString(event.Delta)
s.mu.Unlock()
return s.pushPreview(ctx)
case channel.StreamEventAttachment:
if len(event.Attachments) == 0 {
return nil
}
s.mu.Lock()
s.attachments = append(s.attachments, event.Attachments...)
s.mu.Unlock()
return nil
case channel.StreamEventFinal:
if event.Final == nil {
return nil
}
s.mu.Lock()
final := event.Final.Message
s.final = &final
s.mu.Unlock()
return s.flush(ctx)
case channel.StreamEventError:
text := strings.TrimSpace(event.Error)
if text == "" {
return nil
}
s.mu.Lock()
s.final = &channel.Message{Format: channel.MessageFormatPlain, Text: "Error: " + text}
s.mu.Unlock()
return s.flush(ctx)
}
return nil
}
func (s *wecomOutboundStream) Close(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
s.closed.Store(true)
if s.finalSent.Load() {
return nil
}
return s.flush(ctx)
}
func (s *wecomOutboundStream) flush(ctx context.Context) error {
if s.finalSent.Load() {
return nil
}
msg, streamID := s.snapshotMessage(true)
if msg.IsEmpty() {
return nil
}
if ctxMeta, ok := s.adapter.lookupCallbackContext(msg.Reply); ok {
if err := s.adapter.sendRespondStream(ctx, s.cfg, msg, ctxMeta.ReqID, streamID, true); err != nil {
return err
}
s.finalSent.Store(true)
return nil
}
if err := s.adapter.Send(ctx, s.cfg, channel.OutboundMessage{
Target: s.target,
Message: msg,
}); err != nil {
return err
}
s.finalSent.Store(true)
return nil
}
func (s *wecomOutboundStream) pushPreview(ctx context.Context) error {
if s.finalSent.Load() {
return nil
}
msg, streamID := s.snapshotMessage(false)
text := strings.TrimSpace(msg.PlainText())
if text == "" {
return nil
}
s.mu.Lock()
if s.lastPreview == text {
s.mu.Unlock()
return nil
}
s.mu.Unlock()
if ctxMeta, ok := s.adapter.lookupCallbackContext(msg.Reply); ok {
if err := s.adapter.sendRespondStream(ctx, s.cfg, msg, ctxMeta.ReqID, streamID, false); err != nil {
return err
}
s.mu.Lock()
s.lastPreview = text
s.mu.Unlock()
}
return nil
}
func (s *wecomOutboundStream) snapshotMessage(includeAttachments bool) (channel.Message, string) {
s.mu.Lock()
defer s.mu.Unlock()
msg := channel.Message{}
if s.final != nil {
msg = *s.final
}
if strings.TrimSpace(msg.Text) == "" {
msg.Text = strings.TrimSpace(s.textBuilder.String())
}
if includeAttachments && len(msg.Attachments) == 0 && len(s.attachments) > 0 {
msg.Attachments = append(msg.Attachments, s.attachments...)
}
if msg.Reply == nil && s.reply != nil {
msg.Reply = s.reply
}
if s.streamID == "" {
s.streamID = NewReqID("stream")
}
return msg, s.streamID
}
func (a *WeComAdapter) sendRespondStream(ctx context.Context, cfg channel.ChannelConfig, msg channel.Message, reqID string, streamID string, finish bool) error {
parsed, err := parseConfig(cfg.Credentials)
if err != nil {
return err
}
client := a.getClient(parsed.BotID)
if client == nil {
return fmt.Errorf("wecom connection is not active")
}
payload, cmd, ackReqID, err := buildRespondPayloadWithStream(msg, reqID, streamID, finish)
if err != nil {
return err
}
ack, err := client.Reply(ctx, ackReqID, cmd, payload)
if err != nil {
return err
}
if ack.ErrCode != 0 {
return fmt.Errorf("wecom send failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode)
}
return nil
}
func secondsOrDefault(value int, fallback int) int {
if value > 0 {
return value
}
return fallback
}
@@ -0,0 +1,52 @@
package wecom
import (
"context"
"testing"
"github.com/memohai/memoh/internal/channel"
)
func TestDiscoverSelf(t *testing.T) {
adapter := NewWeComAdapter(nil)
identity, externalID, err := adapter.DiscoverSelf(context.Background(), map[string]any{
"botId": "bot_123",
"secret": "sec",
})
if err != nil {
t.Fatalf("DiscoverSelf error = %v", err)
}
if externalID != "bot_123" {
t.Fatalf("unexpected external id: %q", externalID)
}
if identity["bot_id"] != "bot_123" {
t.Fatalf("unexpected bot_id: %v", identity["bot_id"])
}
if identity["aibot_id"] != "bot_123" {
t.Fatalf("unexpected aibot_id: %v", identity["aibot_id"])
}
if _, ok := identity["name"]; ok {
t.Fatalf("unexpected name field: %v", identity["name"])
}
if _, ok := identity["display_name"]; ok {
t.Fatalf("unexpected display_name field: %v", identity["display_name"])
}
}
func TestOpenStream_FallbackReplyFromSourceMessageID(t *testing.T) {
adapter := NewWeComAdapter(nil)
stream, err := adapter.OpenStream(context.Background(), channel.ChannelConfig{}, "chat_id:chat_1", channel.StreamOptions{
SourceMessageID: "msg_1",
})
if err != nil {
t.Fatalf("OpenStream error = %v", err)
}
ws, ok := stream.(*wecomOutboundStream)
if !ok {
t.Fatalf("unexpected stream type: %T", stream)
}
if ws.reply == nil || ws.reply.MessageID != "msg_1" || ws.reply.Target != "chat_id:chat_1" {
t.Fatalf("unexpected reply fallback: %+v", ws.reply)
}
}
@@ -0,0 +1,396 @@
package wecom
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
type WSClientOptions struct {
URL string
Dialer *websocket.Dialer
Logger *slog.Logger
AckTimeout time.Duration
WriteTimeout time.Duration
ReadTimeout time.Duration
HeartbeatInterval time.Duration
ReconnectBaseDelay time.Duration
ReconnectMaxDelay time.Duration
MaxReconnectAttempts int
}
type WSClient struct {
opts WSClientOptions
logger *slog.Logger
writeMu sync.Mutex
waitMu sync.Mutex
connMu sync.RWMutex
conn *websocket.Conn
waiters map[string]chan wsAck
closed bool
}
type wsAck struct {
frame WSFrame
err error
}
func NewWSClient(opts WSClientOptions) *WSClient {
if strings.TrimSpace(opts.URL) == "" {
opts.URL = defaultWSURL
}
if opts.Logger == nil {
opts.Logger = slog.Default()
}
if opts.AckTimeout <= 0 {
opts.AckTimeout = 8 * time.Second
}
if opts.WriteTimeout <= 0 {
opts.WriteTimeout = 8 * time.Second
}
if opts.ReadTimeout <= 0 {
opts.ReadTimeout = 70 * time.Second
}
if opts.HeartbeatInterval <= 0 {
opts.HeartbeatInterval = 30 * time.Second
}
if opts.ReconnectBaseDelay <= 0 {
opts.ReconnectBaseDelay = 1 * time.Second
}
if opts.ReconnectMaxDelay <= 0 {
opts.ReconnectMaxDelay = 30 * time.Second
}
return &WSClient{
opts: opts,
logger: opts.Logger.With(slog.String("component", "wecom_ws_client")),
waiters: make(map[string]chan wsAck),
}
}
func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
if err := auth.Validate(); err != nil {
return err
}
attempt := 0
for {
err := c.runSession(ctx, auth, onFrame)
if ctx.Err() != nil {
return ctx.Err()
}
if c.isClosed() {
return nil
}
if c.opts.MaxReconnectAttempts >= 0 && attempt >= c.opts.MaxReconnectAttempts {
if err == nil {
return fmt.Errorf("wecom websocket reconnect attempts exceeded")
}
return err
}
delay := c.backoff(attempt)
attempt++
c.logger.Warn("wecom websocket session ended; reconnecting",
slog.Int("attempt", attempt),
slog.Duration("delay", delay),
slog.Any("error", err),
)
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-timer.C:
}
}
}
func (c *WSClient) runSession(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
conn, _, err := c.dial(ctx)
if err != nil {
return err
}
c.setConn(conn)
sessionCtx, cancel := context.WithCancel(ctx)
defer func() {
cancel()
c.clearConn()
c.failAllWaiters(fmt.Errorf("wecom websocket disconnected"))
}()
readErrCh := make(chan error, 1)
go c.readLoop(sessionCtx, onFrame, readErrCh)
if err := c.authenticate(sessionCtx, auth); err != nil {
_ = conn.Close()
return err
}
c.logger.Info("wecom websocket authenticated")
go c.heartbeatLoop(sessionCtx)
select {
case <-sessionCtx.Done():
_ = conn.Close()
return sessionCtx.Err()
case err := <-readErrCh:
_ = conn.Close()
return err
}
}
func (c *WSClient) dial(ctx context.Context) (*websocket.Conn, *http.Response, error) {
dialer := c.opts.Dialer
if dialer == nil {
dialer = websocket.DefaultDialer
}
return dialer.DialContext(ctx, c.opts.URL, nil)
}
func (c *WSClient) authenticate(ctx context.Context, auth AuthCredentials) error {
frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), SubscribeBody{
BotID: strings.TrimSpace(auth.BotID),
Secret: strings.TrimSpace(auth.Secret),
})
if err != nil {
return err
}
ack, err := c.SendWithAck(ctx, frame)
if err != nil {
return err
}
if ack.ErrCode != 0 {
return fmt.Errorf("wecom subscribe failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode)
}
return nil
}
func (c *WSClient) heartbeatLoop(ctx context.Context) {
ticker := time.NewTicker(c.opts.HeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
frame, err := BuildFrame(WSCmdHeartbeat, NewReqID(WSCmdHeartbeat), nil)
if err != nil {
c.logger.Error("build heartbeat frame failed", slog.Any("error", err))
continue
}
err = c.Send(ctx, frame)
if err != nil {
c.logger.Warn("wecom websocket heartbeat failed", slog.Any("error", err))
if conn := c.getConn(); conn != nil {
_ = conn.Close()
}
return
}
}
}
}
func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, WSFrame) error, errCh chan<- error) {
conn := c.getConn()
if conn == nil {
errCh <- fmt.Errorf("wecom websocket connection not ready")
return
}
for {
if c.opts.ReadTimeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(c.opts.ReadTimeout))
}
_, payload, err := conn.ReadMessage()
if err != nil {
errCh <- err
return
}
var frame WSFrame
if err := json.Unmarshal(payload, &frame); err != nil {
c.logger.Warn("decode websocket frame failed", slog.Any("error", err))
continue
}
if c.dispatchAck(frame) {
continue
}
if onFrame == nil {
continue
}
if err := onFrame(ctx, frame); err != nil {
c.logger.Warn("wecom onFrame callback returned error", slog.Any("error", err))
}
}
}
func (c *WSClient) Send(ctx context.Context, frame WSFrame) error {
if strings.TrimSpace(frame.Headers.ReqID) == "" {
return fmt.Errorf("req_id is required")
}
conn := c.getConn()
if conn == nil {
return fmt.Errorf("wecom websocket is not connected")
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
if c.opts.WriteTimeout > 0 {
_ = conn.SetWriteDeadline(time.Now().Add(c.opts.WriteTimeout))
}
if err := conn.WriteJSON(frame); err != nil {
return err
}
return nil
}
func (c *WSClient) SendWithAck(ctx context.Context, frame WSFrame) (WSFrame, error) {
reqID := strings.TrimSpace(frame.Headers.ReqID)
if reqID == "" {
return WSFrame{}, fmt.Errorf("req_id is required")
}
wait := make(chan wsAck, 1)
c.waitMu.Lock()
c.waiters[reqID] = wait
c.waitMu.Unlock()
defer func() {
c.waitMu.Lock()
delete(c.waiters, reqID)
c.waitMu.Unlock()
}()
if err := c.Send(ctx, frame); err != nil {
return WSFrame{}, err
}
timeout := c.opts.AckTimeout
if deadline, ok := ctx.Deadline(); ok {
remaining := time.Until(deadline)
if remaining > 0 && remaining < timeout {
timeout = remaining
}
}
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case <-ctx.Done():
return WSFrame{}, ctx.Err()
case <-timer.C:
return WSFrame{}, fmt.Errorf("wait websocket ack timeout for req_id=%s", reqID)
case ack := <-wait:
if ack.err != nil {
return WSFrame{}, ack.err
}
return ack.frame, nil
}
}
func (c *WSClient) Close() error {
c.connMu.Lock()
c.closed = true
conn := c.conn
c.conn = nil
c.connMu.Unlock()
c.failAllWaiters(fmt.Errorf("wecom websocket client closed"))
if conn == nil {
return nil
}
return conn.Close()
}
func (c *WSClient) Reply(ctx context.Context, reqID string, cmd string, body any) (WSFrame, error) {
frame, err := BuildFrame(cmd, reqID, body)
if err != nil {
return WSFrame{}, err
}
// WeCom callback reply commands are triggered by inbound callback req_id and may
// not always return an explicit ACK frame in production. Waiting for ACK here can
// cause false timeouts even when the platform accepts the reply.
if isRespondCommand(cmd) {
if err := c.Send(ctx, frame); err != nil {
return WSFrame{}, err
}
return WSFrame{}, nil
}
return c.SendWithAck(ctx, frame)
}
func isRespondCommand(cmd string) bool {
switch strings.TrimSpace(cmd) {
case WSCmdRespond, WSCmdRespondWelcome, WSCmdRespondUpdate:
return true
default:
return false
}
}
func (c *WSClient) dispatchAck(frame WSFrame) bool {
reqID := strings.TrimSpace(frame.Headers.ReqID)
if reqID == "" {
return false
}
c.waitMu.Lock()
wait, ok := c.waiters[reqID]
c.waitMu.Unlock()
if !ok {
return false
}
select {
case wait <- wsAck{frame: frame}:
default:
}
return true
}
func (c *WSClient) failAllWaiters(cause error) {
c.waitMu.Lock()
defer c.waitMu.Unlock()
for id, wait := range c.waiters {
delete(c.waiters, id)
select {
case wait <- wsAck{err: cause}:
default:
}
}
}
func (c *WSClient) backoff(attempt int) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := c.opts.ReconnectBaseDelay << attempt
if delay > c.opts.ReconnectMaxDelay {
return c.opts.ReconnectMaxDelay
}
return delay
}
func (c *WSClient) setConn(conn *websocket.Conn) {
c.connMu.Lock()
defer c.connMu.Unlock()
c.conn = conn
}
func (c *WSClient) clearConn() {
c.connMu.Lock()
conn := c.conn
c.conn = nil
c.connMu.Unlock()
if conn != nil {
_ = conn.Close()
}
}
func (c *WSClient) getConn() *websocket.Conn {
c.connMu.RLock()
defer c.connMu.RUnlock()
return c.conn
}
func (c *WSClient) isClosed() bool {
c.connMu.RLock()
defer c.connMu.RUnlock()
return c.closed
}
@@ -0,0 +1,178 @@
package wecom
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/websocket"
)
func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) {
t.Parallel()
var connCount atomic.Int32
callbackSent := make(chan struct{}, 1)
upgrader := websocket.Upgrader{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
n := connCount.Add(1)
var subscribeFrame WSFrame
if err := conn.ReadJSON(&subscribeFrame); err != nil {
return
}
_ = conn.WriteJSON(WSFrame{
Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID},
ErrCode: 0,
})
if n == 1 {
return
}
body, _ := json.Marshal(MessageCallbackBody{
MsgID: "m1",
ChatID: "chat_1",
ChatType: "group",
CreateTime: time.Now().UnixMilli(),
From: CallbackFrom{UserID: "u1"},
MsgType: "text",
Text: &MessageText{Content: "hello"},
})
_ = conn.WriteJSON(WSFrame{
Cmd: WSCmdMsgCallback,
Headers: WSHeaders{ReqID: "cb_req"},
Body: body,
})
select {
case callbackSent <- struct{}{}:
default:
}
<-time.After(100 * time.Millisecond)
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
client := NewWSClient(WSClientOptions{
URL: wsURL,
AckTimeout: 200 * time.Millisecond,
HeartbeatInterval: 10 * time.Second,
ReconnectBaseDelay: 10 * time.Millisecond,
ReconnectMaxDelay: 20 * time.Millisecond,
MaxReconnectAttempts: 5,
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, func(context.Context, WSFrame) error {
cancel()
return nil
})
}()
select {
case <-callbackSent:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting callback on reconnected session")
}
select {
case err := <-runErrCh:
if err == nil || err != context.Canceled {
t.Fatalf("unexpected run error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting run return")
}
if connCount.Load() < 2 {
t.Fatalf("expected reconnect attempts >= 2, got %d", connCount.Load())
}
}
func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) {
t.Parallel()
var connCount atomic.Int32
heartbeatSeen := make(chan struct{}, 1)
upgrader := websocket.Upgrader{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
n := connCount.Add(1)
var subscribeFrame WSFrame
if err := conn.ReadJSON(&subscribeFrame); err != nil {
return
}
_ = conn.WriteJSON(WSFrame{
Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID},
ErrCode: 0,
})
if n > 1 {
return
}
var heartbeatFrame WSFrame
if err := conn.ReadJSON(&heartbeatFrame); err != nil {
return
}
if heartbeatFrame.Cmd == WSCmdHeartbeat {
select {
case heartbeatSeen <- struct{}{}:
default:
}
}
<-time.After(200 * time.Millisecond)
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
client := NewWSClient(WSClientOptions{
URL: wsURL,
HeartbeatInterval: 20 * time.Millisecond,
ReconnectBaseDelay: 10 * time.Millisecond,
ReconnectMaxDelay: 20 * time.Millisecond,
MaxReconnectAttempts: 5,
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, nil)
}()
select {
case <-heartbeatSeen:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting heartbeat frame")
}
<-time.After(250 * time.Millisecond)
cancel()
select {
case err := <-runErrCh:
if err == nil || err != context.Canceled {
t.Fatalf("unexpected run error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting run return")
}
if connCount.Load() < 1 {
t.Fatalf("expected at least one session, got %d", connCount.Load())
}
}