mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(wecom): align adapter with channel stream behavior
Migrate the imported WeCom adapter to current channel interfaces and stabilize stream delivery by preventing heartbeat/reply ACK timeout regressions and post-final overwrite updates.
This commit is contained in:
@@ -0,0 +1,157 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upgrader := websocket.Upgrader{}
|
||||
receivedRespond := make(chan WSFrame, 1)
|
||||
serverErr := make(chan error, 1)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
select {
|
||||
case serverErr <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var subscribeFrame WSFrame
|
||||
if err := conn.ReadJSON(&subscribeFrame); err != nil {
|
||||
select {
|
||||
case serverErr <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := conn.WriteJSON(WSFrame{
|
||||
Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID},
|
||||
ErrCode: 0,
|
||||
}); err != nil {
|
||||
select {
|
||||
case serverErr <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(MessageCallbackBody{
|
||||
MsgID: "msg_1",
|
||||
ChatID: "chat_1",
|
||||
ChatType: "group",
|
||||
CreateTime: time.Now().UnixMilli(),
|
||||
From: CallbackFrom{UserID: "u1"},
|
||||
MsgType: "text",
|
||||
ResponseURL: "https://example.com/resp",
|
||||
Text: &MessageText{Content: "hello"},
|
||||
})
|
||||
if err := conn.WriteJSON(WSFrame{
|
||||
Cmd: WSCmdMsgCallback,
|
||||
Headers: WSHeaders{ReqID: "callback_req_id"},
|
||||
Body: body,
|
||||
}); err != nil {
|
||||
select {
|
||||
case serverErr <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var respondFrame WSFrame
|
||||
if err := conn.ReadJSON(&respondFrame); err != nil {
|
||||
select {
|
||||
case serverErr <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case receivedRespond <- respondFrame:
|
||||
default:
|
||||
}
|
||||
_ = conn.WriteJSON(WSFrame{
|
||||
Headers: WSHeaders{ReqID: respondFrame.Headers.ReqID},
|
||||
ErrCode: 0,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
adapter := NewWeComAdapter(nil)
|
||||
cfg := channel.ChannelConfig{
|
||||
Credentials: map[string]any{
|
||||
"botId": "bot",
|
||||
"secret": "sec",
|
||||
"wsUrl": "ws" + strings.TrimPrefix(server.URL, "http"),
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
inboundCh := make(chan channel.InboundMessage, 1)
|
||||
conn, err := adapter.Connect(ctx, cfg, func(_ context.Context, _ channel.ChannelConfig, msg channel.InboundMessage) error {
|
||||
select {
|
||||
case inboundCh <- msg:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("connect error: %v", err)
|
||||
}
|
||||
defer conn.Stop(context.Background())
|
||||
|
||||
select {
|
||||
case inbound := <-inboundCh:
|
||||
if inbound.Message.ID != "msg_1" {
|
||||
t.Fatalf("unexpected inbound message id: %s", inbound.Message.ID)
|
||||
}
|
||||
case err := <-serverErr:
|
||||
t.Fatalf("server error: %v", err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting inbound callback")
|
||||
}
|
||||
|
||||
err = adapter.Send(context.Background(), cfg, channel.OutboundMessage{
|
||||
Target: "chat_id:chat_1",
|
||||
Message: channel.Message{
|
||||
Format: channel.MessageFormatPlain,
|
||||
Text: "reply content",
|
||||
Reply: &channel.ReplyRef{
|
||||
MessageID: "msg_1",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("send error: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case frame := <-receivedRespond:
|
||||
if frame.Cmd != WSCmdRespond {
|
||||
t.Fatalf("expected cmd=%s got=%s", WSCmdRespond, frame.Cmd)
|
||||
}
|
||||
if frame.Headers.ReqID != "callback_req_id" {
|
||||
t.Fatalf("expected req_id callback_req_id got=%s", frame.Headers.ReqID)
|
||||
}
|
||||
case err := <-serverErr:
|
||||
t.Fatalf("server error: %v", err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting respond frame")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type callbackContext struct {
|
||||
ReqID string
|
||||
ResponseURL string
|
||||
ChatID string
|
||||
UserID string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type callbackContextCache struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]callbackContext
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func newCallbackContextCache(ttl time.Duration) *callbackContextCache {
|
||||
if ttl <= 0 {
|
||||
ttl = 24 * time.Hour
|
||||
}
|
||||
return &callbackContextCache{
|
||||
items: make(map[string]callbackContext),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *callbackContextCache) Put(messageID string, ctx callbackContext) {
|
||||
key := strings.TrimSpace(messageID)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if ctx.CreatedAt.IsZero() {
|
||||
ctx.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.items[key] = ctx
|
||||
c.gcLocked()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *callbackContextCache) Get(messageID string) (callbackContext, bool) {
|
||||
key := strings.TrimSpace(messageID)
|
||||
if key == "" {
|
||||
return callbackContext{}, false
|
||||
}
|
||||
c.mu.RLock()
|
||||
item, ok := c.items[key]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return callbackContext{}, false
|
||||
}
|
||||
if time.Since(item.CreatedAt) > c.ttl {
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.mu.Unlock()
|
||||
return callbackContext{}, false
|
||||
}
|
||||
return item, true
|
||||
}
|
||||
|
||||
func (c *callbackContextCache) gcLocked() {
|
||||
if len(c.items) < 512 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
for key, item := range c.items {
|
||||
if now.Sub(item.CreatedAt) > c.ttl {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCallbackContextCache_PutGet(t *testing.T) {
|
||||
cache := newCallbackContextCache(1 * time.Hour)
|
||||
cache.Put("m1", callbackContext{ReqID: "r1"})
|
||||
got, ok := cache.Get("m1")
|
||||
if !ok {
|
||||
t.Fatal("expected cache hit")
|
||||
}
|
||||
if got.ReqID != "r1" {
|
||||
t.Fatalf("unexpected req id: %q", got.ReqID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackContextCache_Expires(t *testing.T) {
|
||||
cache := newCallbackContextCache(1 * time.Second)
|
||||
cache.Put("m1", callbackContext{
|
||||
ReqID: "r1",
|
||||
CreatedAt: time.Now().Add(-2 * time.Second),
|
||||
})
|
||||
if _, ok := cache.Get("m1"); ok {
|
||||
t.Fatal("expected cache miss due to expiry")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
BotID string
|
||||
Secret string
|
||||
WSURL string
|
||||
HeartbeatSeconds int
|
||||
AckTimeoutSeconds int
|
||||
WriteTimeoutSeconds int
|
||||
ReadTimeoutSeconds int
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
ChatID string
|
||||
UserID string
|
||||
}
|
||||
|
||||
func normalizeConfig(raw map[string]any) (map[string]any, error) {
|
||||
cfg, err := parseConfig(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := map[string]any{
|
||||
"botId": cfg.BotID,
|
||||
"secret": cfg.Secret,
|
||||
}
|
||||
if cfg.WSURL != "" {
|
||||
out["wsUrl"] = cfg.WSURL
|
||||
}
|
||||
if cfg.HeartbeatSeconds > 0 {
|
||||
out["heartbeatSeconds"] = cfg.HeartbeatSeconds
|
||||
}
|
||||
if cfg.AckTimeoutSeconds > 0 {
|
||||
out["ackTimeoutSeconds"] = cfg.AckTimeoutSeconds
|
||||
}
|
||||
if cfg.WriteTimeoutSeconds > 0 {
|
||||
out["writeTimeoutSeconds"] = cfg.WriteTimeoutSeconds
|
||||
}
|
||||
if cfg.ReadTimeoutSeconds > 0 {
|
||||
out["readTimeoutSeconds"] = cfg.ReadTimeoutSeconds
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func normalizeUserConfig(raw map[string]any) (map[string]any, error) {
|
||||
cfg, err := parseUserConfig(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := map[string]any{}
|
||||
if cfg.ChatID != "" {
|
||||
out["chat_id"] = cfg.ChatID
|
||||
}
|
||||
if cfg.UserID != "" {
|
||||
out["user_id"] = cfg.UserID
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func parseConfig(raw map[string]any) (Config, error) {
|
||||
cfg := Config{
|
||||
BotID: strings.TrimSpace(channel.ReadString(raw, "botId", "bot_id")),
|
||||
Secret: strings.TrimSpace(channel.ReadString(raw, "secret")),
|
||||
WSURL: strings.TrimSpace(channel.ReadString(raw, "wsUrl", "ws_url")),
|
||||
}
|
||||
if value, ok := readInt(raw, "heartbeatSeconds", "heartbeat_seconds"); ok {
|
||||
cfg.HeartbeatSeconds = value
|
||||
}
|
||||
if value, ok := readInt(raw, "ackTimeoutSeconds", "ack_timeout_seconds"); ok {
|
||||
cfg.AckTimeoutSeconds = value
|
||||
}
|
||||
if value, ok := readInt(raw, "writeTimeoutSeconds", "write_timeout_seconds"); ok {
|
||||
cfg.WriteTimeoutSeconds = value
|
||||
}
|
||||
if value, ok := readInt(raw, "readTimeoutSeconds", "read_timeout_seconds"); ok {
|
||||
cfg.ReadTimeoutSeconds = value
|
||||
}
|
||||
if cfg.BotID == "" || cfg.Secret == "" {
|
||||
return Config{}, fmt.Errorf("wecom botId and secret are required")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func parseUserConfig(raw map[string]any) (UserConfig, error) {
|
||||
cfg := UserConfig{
|
||||
ChatID: strings.TrimSpace(channel.ReadString(raw, "chatId", "chat_id")),
|
||||
UserID: strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")),
|
||||
}
|
||||
if cfg.ChatID == "" && cfg.UserID == "" {
|
||||
return UserConfig{}, fmt.Errorf("wecom user config requires chat_id or user_id")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func resolveTarget(raw map[string]any) (string, error) {
|
||||
cfg, err := parseUserConfig(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if cfg.ChatID != "" {
|
||||
return "chat_id:" + cfg.ChatID, nil
|
||||
}
|
||||
return "user_id:" + cfg.UserID, nil
|
||||
}
|
||||
|
||||
func normalizeTarget(raw string) string {
|
||||
kind, id, ok := parseTarget(raw)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return kind + ":" + id
|
||||
}
|
||||
|
||||
func parseTarget(raw string) (kind string, id string, ok bool) {
|
||||
v := strings.TrimSpace(raw)
|
||||
if v == "" {
|
||||
return "", "", false
|
||||
}
|
||||
v = strings.TrimPrefix(v, "wecom:")
|
||||
v = strings.TrimPrefix(v, "workwx:")
|
||||
v = strings.TrimSpace(v)
|
||||
lv := strings.ToLower(v)
|
||||
switch {
|
||||
case strings.HasPrefix(lv, "chat_id:"):
|
||||
id = strings.TrimSpace(v[len("chat_id:"):])
|
||||
return "chat_id", id, id != ""
|
||||
case strings.HasPrefix(lv, "chat:"):
|
||||
id = strings.TrimSpace(v[len("chat:"):])
|
||||
return "chat_id", id, id != ""
|
||||
case strings.HasPrefix(lv, "group:"):
|
||||
id = strings.TrimSpace(v[len("group:"):])
|
||||
return "chat_id", id, id != ""
|
||||
case strings.HasPrefix(lv, "user_id:"):
|
||||
id = strings.TrimSpace(v[len("user_id:"):])
|
||||
return "user_id", id, id != ""
|
||||
case strings.HasPrefix(lv, "user:"):
|
||||
id = strings.TrimSpace(v[len("user:"):])
|
||||
return "user_id", id, id != ""
|
||||
default:
|
||||
return "chat_id", v, true
|
||||
}
|
||||
}
|
||||
|
||||
func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool {
|
||||
cfg, err := parseUserConfig(raw)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if value := strings.TrimSpace(criteria.Attribute("chat_id")); value != "" && value == cfg.ChatID {
|
||||
return true
|
||||
}
|
||||
if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID {
|
||||
return true
|
||||
}
|
||||
if criteria.SubjectID != "" && (criteria.SubjectID == cfg.ChatID || criteria.SubjectID == cfg.UserID) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildUserConfig(identity channel.Identity) map[string]any {
|
||||
out := map[string]any{}
|
||||
if v := strings.TrimSpace(identity.Attribute("chat_id")); v != "" {
|
||||
out["chat_id"] = v
|
||||
}
|
||||
if v := strings.TrimSpace(identity.Attribute("user_id")); v != "" {
|
||||
out["user_id"] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readInt(raw map[string]any, keys ...string) (int, bool) {
|
||||
for _, key := range keys {
|
||||
value, ok := raw[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v, true
|
||||
case int32:
|
||||
return int(v), true
|
||||
case int64:
|
||||
return int(v), true
|
||||
case float64:
|
||||
return int(v), true
|
||||
case float32:
|
||||
return int(v), true
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
parsed, err := strconv.Atoi(trimmed)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return parsed, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package wecom
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
cfg, err := parseConfig(map[string]any{
|
||||
"botId": "bot-1",
|
||||
"secret": "sec-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig error = %v", err)
|
||||
}
|
||||
if cfg.BotID != "bot-1" || cfg.Secret != "sec-1" {
|
||||
t.Fatalf("unexpected config: %+v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTarget(t *testing.T) {
|
||||
kind, id, ok := parseTarget("chat_id:abc")
|
||||
if !ok || kind != "chat_id" || id != "abc" {
|
||||
t.Fatalf("unexpected target parse result: ok=%v kind=%q id=%q", ok, kind, id)
|
||||
}
|
||||
kind, id, ok = parseTarget("user_id:zhangsan")
|
||||
if !ok || kind != "user_id" || id != "zhangsan" {
|
||||
t.Fatalf("unexpected target parse result: ok=%v kind=%q id=%q", ok, kind, id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("ciphertext is empty")
|
||||
}
|
||||
key, err := base64.StdEncoding.DecodeString(aesKeyBase64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode aes key failed: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("invalid aes key length: %d", len(key))
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("invalid ciphertext block size")
|
||||
}
|
||||
iv := key[:aes.BlockSize]
|
||||
out := make([]byte, len(ciphertext))
|
||||
cipher.NewCBCDecrypter(block, iv).CryptBlocks(out, ciphertext)
|
||||
plain, err := pkcs7Unpad(out, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
func pkcs7Unpad(data []byte, maxPad int) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, fmt.Errorf("pkcs7 payload is empty")
|
||||
}
|
||||
pad := int(data[len(data)-1])
|
||||
if pad <= 0 || pad > maxPad || pad > len(data) {
|
||||
return nil, fmt.Errorf("invalid pkcs7 padding length: %d", pad)
|
||||
}
|
||||
padding := bytes.Repeat([]byte{byte(pad)}, pad)
|
||||
if !bytes.Equal(data[len(data)-pad:], padding) {
|
||||
return nil, fmt.Errorf("invalid pkcs7 padding bytes")
|
||||
}
|
||||
return data[:len(data)-pad], nil
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecryptFileAES256CBC(t *testing.T) {
|
||||
key := []byte("0123456789abcdef0123456789abcdef")
|
||||
plain := []byte("hello-wecom-aibot")
|
||||
ciphertext := encryptPKCS7To32(t, key, plain)
|
||||
|
||||
out, err := DecryptFileAES256CBC(ciphertext, base64.StdEncoding.EncodeToString(key))
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptFileAES256CBC error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, plain) {
|
||||
t.Fatalf("plaintext mismatch: got=%q want=%q", string(out), string(plain))
|
||||
}
|
||||
}
|
||||
|
||||
func encryptPKCS7To32(t *testing.T, key []byte, plain []byte) []byte {
|
||||
t.Helper()
|
||||
padded := pkcs7PadTo32(plain)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
iv := key[:aes.BlockSize]
|
||||
out := make([]byte, len(padded))
|
||||
cipher.NewCBCEncrypter(block, iv).CryptBlocks(out, padded)
|
||||
return out
|
||||
}
|
||||
|
||||
func pkcs7PadTo32(data []byte) []byte {
|
||||
const blockSize = 32
|
||||
pad := blockSize - (len(data) % blockSize)
|
||||
if pad == 0 {
|
||||
pad = blockSize
|
||||
}
|
||||
out := make([]byte, len(data)+pad)
|
||||
copy(out, data)
|
||||
for i := len(data); i < len(out); i++ {
|
||||
out[i] = byte(pad)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HTTPClientOptions struct {
|
||||
Logger *slog.Logger
|
||||
Client *http.Client
|
||||
Transport *http.Transport
|
||||
Timeout time.Duration
|
||||
MaxIdleConns int
|
||||
MaxIdleConnsPerHost int
|
||||
IdleConnTimeout time.Duration
|
||||
TLSHandshakeTimeout time.Duration
|
||||
ResponseHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
type HTTPClient struct {
|
||||
client *http.Client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
type DownloadedFile struct {
|
||||
Data []byte
|
||||
FileName string
|
||||
ContentType string
|
||||
}
|
||||
|
||||
func NewHTTPClient(opts HTTPClientOptions) *HTTPClient {
|
||||
if opts.Logger == nil {
|
||||
opts.Logger = slog.Default()
|
||||
}
|
||||
if opts.Timeout <= 0 {
|
||||
opts.Timeout = 20 * time.Second
|
||||
}
|
||||
if opts.IdleConnTimeout <= 0 {
|
||||
opts.IdleConnTimeout = 90 * time.Second
|
||||
}
|
||||
if opts.TLSHandshakeTimeout <= 0 {
|
||||
opts.TLSHandshakeTimeout = 10 * time.Second
|
||||
}
|
||||
if opts.ResponseHeaderTimeout <= 0 {
|
||||
opts.ResponseHeaderTimeout = 15 * time.Second
|
||||
}
|
||||
if opts.MaxIdleConns <= 0 {
|
||||
opts.MaxIdleConns = 100
|
||||
}
|
||||
if opts.MaxIdleConnsPerHost <= 0 {
|
||||
opts.MaxIdleConnsPerHost = 10
|
||||
}
|
||||
client := opts.Client
|
||||
if client == nil {
|
||||
transport := opts.Transport
|
||||
if transport == nil {
|
||||
transport = &http.Transport{
|
||||
MaxIdleConns: opts.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: opts.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: opts.IdleConnTimeout,
|
||||
TLSHandshakeTimeout: opts.TLSHandshakeTimeout,
|
||||
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
|
||||
}
|
||||
}
|
||||
client = &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: opts.Timeout,
|
||||
}
|
||||
}
|
||||
return &HTTPClient{
|
||||
client: client,
|
||||
logger: opts.Logger.With(slog.String("component", "wecom_http_client")),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTPClient) DownloadFile(ctx context.Context, rawURL string) (DownloadedFile, error) {
|
||||
u := strings.TrimSpace(rawURL)
|
||||
if u == "" {
|
||||
return DownloadedFile{}, fmt.Errorf("download url is required")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return DownloadedFile{}, err
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return DownloadedFile{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return DownloadedFile{}, fmt.Errorf("download failed with status %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return DownloadedFile{}, err
|
||||
}
|
||||
fileName := extractFilename(resp.Header.Get("Content-Disposition"), u)
|
||||
return DownloadedFile{
|
||||
Data: data,
|
||||
FileName: fileName,
|
||||
ContentType: strings.TrimSpace(resp.Header.Get("Content-Type")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *HTTPClient) DownloadAndDecryptFile(ctx context.Context, rawURL string, aesKey string) (DownloadedFile, error) {
|
||||
file, err := c.DownloadFile(ctx, rawURL)
|
||||
if err != nil {
|
||||
return DownloadedFile{}, err
|
||||
}
|
||||
plain, err := DecryptFileAES256CBC(file.Data, aesKey)
|
||||
if err != nil {
|
||||
return DownloadedFile{}, err
|
||||
}
|
||||
file.Data = plain
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func extractFilename(contentDisposition, rawURL string) string {
|
||||
cd := strings.TrimSpace(contentDisposition)
|
||||
if cd != "" {
|
||||
_, params, err := mime.ParseMediaType(cd)
|
||||
if err == nil {
|
||||
if v := strings.TrimSpace(params["filename*"]); v != "" {
|
||||
if decoded := decodeRFC5987Filename(v); decoded != "" {
|
||||
return decoded
|
||||
}
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(params["filename"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
parsed, err := url.Parse(strings.TrimSpace(rawURL))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base := strings.TrimSpace(path.Base(parsed.Path))
|
||||
if base == "." || base == "/" {
|
||||
return ""
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func decodeRFC5987Filename(value string) string {
|
||||
parts := strings.SplitN(strings.TrimSpace(value), "''", 2)
|
||||
encoded := strings.TrimSpace(value)
|
||||
if len(parts) == 2 {
|
||||
encoded = strings.TrimSpace(parts[1])
|
||||
}
|
||||
if encoded == "" {
|
||||
return ""
|
||||
}
|
||||
decoded, err := url.QueryUnescape(encoded)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(decoded)
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDownloadFile_ParsesFilename(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename*=UTF-8''hello%20wecom.txt")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewHTTPClient(HTTPClientOptions{})
|
||||
file, err := client.DownloadFile(context.Background(), srv.URL+"/download.bin")
|
||||
if err != nil {
|
||||
t.Fatalf("DownloadFile error = %v", err)
|
||||
}
|
||||
if file.FileName != "hello wecom.txt" {
|
||||
t.Fatalf("unexpected filename: got=%q", file.FileName)
|
||||
}
|
||||
if string(file.Data) != "ok" {
|
||||
t.Fatalf("unexpected payload: got=%q", string(file.Data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFilename_FallbackPath(t *testing.T) {
|
||||
got := extractFilename("", "https://example.com/files/a.png")
|
||||
if got != "a.png" {
|
||||
t.Fatalf("unexpected filename fallback: got=%q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,338 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
func (a *WeComAdapter) handleFrame(ctx context.Context, cfg channel.ChannelConfig, frame WSFrame, handler channel.InboundHandler) error {
|
||||
switch frame.Cmd {
|
||||
case WSCmdMsgCallback:
|
||||
if handler == nil {
|
||||
return nil
|
||||
}
|
||||
var body MessageCallbackBody
|
||||
if err := frame.DecodeBody(&body); err != nil {
|
||||
return err
|
||||
}
|
||||
msg, ok := buildInboundMessage(body, frame.Headers.ReqID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
a.rememberCallback(body.MsgID, frame.Headers.ReqID, body.ResponseURL, body.ChatID, body.From.UserID)
|
||||
return handler(ctx, cfg, msg)
|
||||
case WSCmdEventCallback:
|
||||
if handler == nil {
|
||||
return nil
|
||||
}
|
||||
var body EventCallbackBody
|
||||
if err := frame.DecodeBody(&body); err != nil {
|
||||
return err
|
||||
}
|
||||
msg, ok := buildInboundEventMessage(body, frame.Headers.ReqID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
a.rememberCallback(body.MsgID, frame.Headers.ReqID, body.ResponseURL, body.ChatID, body.From.UserID)
|
||||
return handler(ctx, cfg, msg)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildInboundMessage(body MessageCallbackBody, reqID string) (channel.InboundMessage, bool) {
|
||||
text, attachments := extractBodyContent(body)
|
||||
if strings.TrimSpace(text) == "" && len(attachments) == 0 {
|
||||
return channel.InboundMessage{}, false
|
||||
}
|
||||
target := resolveDeliveryTarget(body.ChatID, body.From.UserID)
|
||||
if target == "" {
|
||||
return channel.InboundMessage{}, false
|
||||
}
|
||||
convType := normalizeConversationType(body.ChatType)
|
||||
convID := strings.TrimSpace(body.ChatID)
|
||||
if convID == "" {
|
||||
convID = strings.TrimSpace(body.From.UserID)
|
||||
}
|
||||
msg := channel.InboundMessage{
|
||||
Channel: Type,
|
||||
Message: channel.Message{
|
||||
ID: strings.TrimSpace(body.MsgID),
|
||||
Format: channel.MessageFormatPlain,
|
||||
Text: strings.TrimSpace(text),
|
||||
Attachments: attachments,
|
||||
Metadata: map[string]any{
|
||||
"response_url": strings.TrimSpace(body.ResponseURL),
|
||||
},
|
||||
},
|
||||
ReplyTarget: target,
|
||||
Sender: channel.Identity{
|
||||
SubjectID: strings.TrimSpace(body.From.UserID),
|
||||
Attributes: map[string]string{
|
||||
"user_id": strings.TrimSpace(body.From.UserID),
|
||||
"chat_id": strings.TrimSpace(body.ChatID),
|
||||
},
|
||||
},
|
||||
Conversation: channel.Conversation{
|
||||
ID: convID,
|
||||
Type: convType,
|
||||
},
|
||||
ReceivedAt: parseCreateTime(body.CreateTime),
|
||||
Source: "wecom",
|
||||
Metadata: map[string]any{
|
||||
"req_id": strings.TrimSpace(reqID),
|
||||
"chat_id": strings.TrimSpace(body.ChatID),
|
||||
"chat_type": strings.TrimSpace(body.ChatType),
|
||||
"response_url": strings.TrimSpace(body.ResponseURL),
|
||||
},
|
||||
}
|
||||
return msg, true
|
||||
}
|
||||
|
||||
func buildInboundEventMessage(body EventCallbackBody, reqID string) (channel.InboundMessage, bool) {
|
||||
eventType := normalizeEventType(body.Event.EventType, body.Event.EventType2)
|
||||
if eventType == "" {
|
||||
return channel.InboundMessage{}, false
|
||||
}
|
||||
target := resolveDeliveryTarget(body.ChatID, body.From.UserID)
|
||||
if target == "" {
|
||||
return channel.InboundMessage{}, false
|
||||
}
|
||||
convType := normalizeConversationType(body.ChatType)
|
||||
convID := strings.TrimSpace(body.ChatID)
|
||||
if convID == "" {
|
||||
convID = strings.TrimSpace(body.From.UserID)
|
||||
}
|
||||
return channel.InboundMessage{
|
||||
Channel: Type,
|
||||
Message: channel.Message{
|
||||
ID: strings.TrimSpace(body.MsgID),
|
||||
Format: channel.MessageFormatPlain,
|
||||
Text: eventType,
|
||||
Metadata: map[string]any{
|
||||
"event_type": eventType,
|
||||
"task_id": strings.TrimSpace(body.Event.TaskID),
|
||||
"event_key": strings.TrimSpace(body.Event.EventKey),
|
||||
"event_code": strings.TrimSpace(body.Event.Code),
|
||||
"event_reason": strings.TrimSpace(body.Event.Reason),
|
||||
"task_status": strings.TrimSpace(body.Task.TaskStatus),
|
||||
"response_url": strings.TrimSpace(body.ResponseURL),
|
||||
},
|
||||
},
|
||||
ReplyTarget: target,
|
||||
Sender: channel.Identity{
|
||||
SubjectID: strings.TrimSpace(body.From.UserID),
|
||||
Attributes: map[string]string{
|
||||
"user_id": strings.TrimSpace(body.From.UserID),
|
||||
"chat_id": strings.TrimSpace(body.ChatID),
|
||||
},
|
||||
},
|
||||
Conversation: channel.Conversation{
|
||||
ID: convID,
|
||||
Type: convType,
|
||||
},
|
||||
ReceivedAt: parseCreateTime(body.CreateTime),
|
||||
Source: "wecom",
|
||||
Metadata: map[string]any{
|
||||
"is_event": true,
|
||||
"event_type": eventType,
|
||||
"event_key": strings.TrimSpace(body.Event.EventKey),
|
||||
"event_code": strings.TrimSpace(body.Event.Code),
|
||||
"event_reason": strings.TrimSpace(body.Event.Reason),
|
||||
"req_id": strings.TrimSpace(reqID),
|
||||
"response_url": strings.TrimSpace(body.ResponseURL),
|
||||
},
|
||||
}, true
|
||||
}
|
||||
|
||||
func extractBodyContent(body MessageCallbackBody) (string, []channel.Attachment) {
|
||||
switch strings.ToLower(strings.TrimSpace(body.MsgType)) {
|
||||
case "text":
|
||||
if body.Text == nil {
|
||||
return "", nil
|
||||
}
|
||||
return strings.TrimSpace(body.Text.Content), nil
|
||||
case "markdown":
|
||||
if body.Markdown == nil {
|
||||
return "", nil
|
||||
}
|
||||
return strings.TrimSpace(body.Markdown.Content), nil
|
||||
case "voice":
|
||||
if body.Voice == nil {
|
||||
return "", nil
|
||||
}
|
||||
return strings.TrimSpace(body.Voice.Content), nil
|
||||
case "video":
|
||||
if body.Video == nil {
|
||||
return "", nil
|
||||
}
|
||||
att := channel.Attachment{
|
||||
Type: channel.AttachmentVideo,
|
||||
URL: strings.TrimSpace(body.Video.URL),
|
||||
SourcePlatform: Type.String(),
|
||||
Metadata: map[string]any{
|
||||
"aeskey": strings.TrimSpace(body.Video.AESKey),
|
||||
},
|
||||
}
|
||||
return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)}
|
||||
case "image":
|
||||
if body.Image == nil {
|
||||
return "", nil
|
||||
}
|
||||
att := channel.Attachment{
|
||||
Type: channel.AttachmentImage,
|
||||
URL: strings.TrimSpace(body.Image.URL),
|
||||
SourcePlatform: Type.String(),
|
||||
Metadata: map[string]any{
|
||||
"aeskey": strings.TrimSpace(body.Image.AESKey),
|
||||
},
|
||||
}
|
||||
return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)}
|
||||
case "file":
|
||||
if body.File == nil {
|
||||
return "", nil
|
||||
}
|
||||
att := channel.Attachment{
|
||||
Type: channel.AttachmentFile,
|
||||
URL: strings.TrimSpace(body.File.URL),
|
||||
SourcePlatform: Type.String(),
|
||||
Name: strings.TrimSpace(body.File.FileName),
|
||||
Metadata: map[string]any{
|
||||
"aeskey": strings.TrimSpace(body.File.AESKey),
|
||||
},
|
||||
}
|
||||
return "", []channel.Attachment{channel.NormalizeInboundChannelAttachment(att)}
|
||||
case "mixed":
|
||||
var textParts []string
|
||||
attachments := make([]channel.Attachment, 0)
|
||||
for _, item := range body.Mixed {
|
||||
switch strings.ToLower(strings.TrimSpace(item.MsgType)) {
|
||||
case "text":
|
||||
if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" {
|
||||
textParts = append(textParts, strings.TrimSpace(item.Text.Content))
|
||||
}
|
||||
case "markdown":
|
||||
if item.Markdown != nil && strings.TrimSpace(item.Markdown.Content) != "" {
|
||||
textParts = append(textParts, strings.TrimSpace(item.Markdown.Content))
|
||||
}
|
||||
case "image":
|
||||
if item.Image == nil {
|
||||
continue
|
||||
}
|
||||
att := channel.Attachment{
|
||||
Type: channel.AttachmentImage,
|
||||
URL: strings.TrimSpace(item.Image.URL),
|
||||
SourcePlatform: Type.String(),
|
||||
Metadata: map[string]any{
|
||||
"aeskey": strings.TrimSpace(item.Image.AESKey),
|
||||
},
|
||||
}
|
||||
attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att))
|
||||
case "file":
|
||||
if item.File == nil {
|
||||
continue
|
||||
}
|
||||
att := channel.Attachment{
|
||||
Type: channel.AttachmentFile,
|
||||
URL: strings.TrimSpace(item.File.URL),
|
||||
SourcePlatform: Type.String(),
|
||||
Name: strings.TrimSpace(item.File.FileName),
|
||||
Metadata: map[string]any{
|
||||
"aeskey": strings.TrimSpace(item.File.AESKey),
|
||||
},
|
||||
}
|
||||
attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att))
|
||||
case "video":
|
||||
if item.Video == nil {
|
||||
continue
|
||||
}
|
||||
att := channel.Attachment{
|
||||
Type: channel.AttachmentVideo,
|
||||
URL: strings.TrimSpace(item.Video.URL),
|
||||
SourcePlatform: Type.String(),
|
||||
Metadata: map[string]any{
|
||||
"aeskey": strings.TrimSpace(item.Video.AESKey),
|
||||
},
|
||||
}
|
||||
attachments = append(attachments, channel.NormalizeInboundChannelAttachment(att))
|
||||
case "voice":
|
||||
if item.Voice != nil && strings.TrimSpace(item.Voice.Content) != "" {
|
||||
textParts = append(textParts, strings.TrimSpace(item.Voice.Content))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(textParts, "\n"), attachments
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func resolveDeliveryTarget(chatID, userID string) string {
|
||||
if v := strings.TrimSpace(chatID); v != "" {
|
||||
return "chat_id:" + v
|
||||
}
|
||||
if v := strings.TrimSpace(userID); v != "" {
|
||||
return "user_id:" + v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeConversationType(chatType string) string {
|
||||
if strings.EqualFold(strings.TrimSpace(chatType), "group") {
|
||||
return "group"
|
||||
}
|
||||
return "private"
|
||||
}
|
||||
|
||||
func parseCreateTime(ts int64) time.Time {
|
||||
if t := unixMilliseconds(ts); !t.IsZero() {
|
||||
return t
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) rememberCallback(msgID, reqID, responseURL, chatID, userID string) {
|
||||
if a == nil || a.cache == nil {
|
||||
return
|
||||
}
|
||||
msgID = strings.TrimSpace(msgID)
|
||||
if msgID == "" {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(reqID) == "" {
|
||||
return
|
||||
}
|
||||
a.cache.Put(msgID, callbackContext{
|
||||
ReqID: strings.TrimSpace(reqID),
|
||||
ResponseURL: strings.TrimSpace(responseURL),
|
||||
ChatID: strings.TrimSpace(chatID),
|
||||
UserID: strings.TrimSpace(userID),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeEventType(values ...string) string {
|
||||
candidates := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
v = strings.TrimSpace(strings.ToLower(v))
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
v = strings.ReplaceAll(v, "-", "_")
|
||||
v = strings.ReplaceAll(v, " ", "_")
|
||||
candidates = append(candidates, v)
|
||||
}
|
||||
for _, v := range candidates {
|
||||
switch v {
|
||||
case "enter_chat", "template_card_event", "feedback_event":
|
||||
return v
|
||||
}
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
return candidates[0]
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package wecom
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildInboundMessage_Text(t *testing.T) {
|
||||
msg, ok := buildInboundMessage(MessageCallbackBody{
|
||||
MsgID: "m1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: "group",
|
||||
From: CallbackFrom{
|
||||
UserID: "u1",
|
||||
},
|
||||
MsgType: "text",
|
||||
Text: &MessageText{
|
||||
Content: "hello",
|
||||
},
|
||||
}, "req-1")
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message")
|
||||
}
|
||||
if msg.Message.Text != "hello" {
|
||||
t.Fatalf("unexpected text: %q", msg.Message.Text)
|
||||
}
|
||||
if msg.ReplyTarget != "chat_id:chat-1" {
|
||||
t.Fatalf("unexpected target: %q", msg.ReplyTarget)
|
||||
}
|
||||
if msg.Metadata["req_id"] != "req-1" {
|
||||
t.Fatalf("unexpected req_id: %v", msg.Metadata["req_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildInboundMessage_Markdown(t *testing.T) {
|
||||
msg, ok := buildInboundMessage(MessageCallbackBody{
|
||||
MsgID: "m2",
|
||||
ChatID: "chat-2",
|
||||
ChatType: "private",
|
||||
From: CallbackFrom{UserID: "u2"},
|
||||
MsgType: "markdown",
|
||||
Markdown: &MessageText{
|
||||
Content: "**hello**",
|
||||
},
|
||||
}, "req-2")
|
||||
if !ok {
|
||||
t.Fatal("expected inbound markdown message")
|
||||
}
|
||||
if msg.Message.Text != "**hello**" {
|
||||
t.Fatalf("unexpected markdown text: %q", msg.Message.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildInboundMessage_Video(t *testing.T) {
|
||||
msg, ok := buildInboundMessage(MessageCallbackBody{
|
||||
MsgID: "m3",
|
||||
ChatID: "chat-3",
|
||||
ChatType: "group",
|
||||
From: CallbackFrom{UserID: "u3"},
|
||||
MsgType: "video",
|
||||
Video: &MessageVideo{
|
||||
URL: "https://example.com/v.mp4",
|
||||
AESKey: "k",
|
||||
},
|
||||
}, "req-3")
|
||||
if !ok {
|
||||
t.Fatal("expected inbound video message")
|
||||
}
|
||||
if len(msg.Message.Attachments) != 1 {
|
||||
t.Fatalf("expected one attachment, got %d", len(msg.Message.Attachments))
|
||||
}
|
||||
if msg.Message.Attachments[0].Type != "video" {
|
||||
t.Fatalf("unexpected attachment type: %s", msg.Message.Attachments[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildInboundEventMessage_NormalizeEventType(t *testing.T) {
|
||||
msg, ok := buildInboundEventMessage(EventCallbackBody{
|
||||
MsgID: "e1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: "group",
|
||||
From: CallbackFrom{UserID: "u1"},
|
||||
CreateTime: 1,
|
||||
MsgType: "event",
|
||||
Event: EventPayload{
|
||||
EventType2: "template-card-event",
|
||||
EventKey: "btn_1",
|
||||
TaskID: "task_1",
|
||||
},
|
||||
Task: EventTask{
|
||||
TaskStatus: "done",
|
||||
},
|
||||
}, "req-e1")
|
||||
if !ok {
|
||||
t.Fatal("expected event message")
|
||||
}
|
||||
if msg.Message.Text != "template_card_event" {
|
||||
t.Fatalf("unexpected normalized event text: %q", msg.Message.Text)
|
||||
}
|
||||
if msg.Message.Metadata["event_key"] != "btn_1" {
|
||||
t.Fatalf("unexpected event key: %v", msg.Message.Metadata["event_key"])
|
||||
}
|
||||
if msg.Message.Metadata["task_status"] != "done" {
|
||||
t.Fatalf("unexpected task status: %v", msg.Message.Metadata["task_status"])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,309 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
func (a *WeComAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) {
|
||||
_ = cfg
|
||||
if a.http == nil {
|
||||
return channel.AttachmentPayload{}, fmt.Errorf("wecom http client not configured")
|
||||
}
|
||||
url := strings.TrimSpace(attachment.URL)
|
||||
if url == "" {
|
||||
return channel.AttachmentPayload{}, fmt.Errorf("wecom attachment url is required")
|
||||
}
|
||||
aesKey := ""
|
||||
if attachment.Metadata != nil {
|
||||
if value, ok := attachment.Metadata["aeskey"].(string); ok {
|
||||
aesKey = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
var file DownloadedFile
|
||||
var err error
|
||||
if aesKey != "" {
|
||||
file, err = a.http.DownloadAndDecryptFile(ctx, url, aesKey)
|
||||
} else {
|
||||
file, err = a.http.DownloadFile(ctx, url)
|
||||
}
|
||||
if err != nil {
|
||||
return channel.AttachmentPayload{}, err
|
||||
}
|
||||
return channel.AttachmentPayload{
|
||||
Reader: io.NopCloser(bytes.NewReader(file.Data)),
|
||||
Mime: strings.TrimSpace(file.ContentType),
|
||||
Name: strings.TrimSpace(file.FileName),
|
||||
Size: int64(len(file.Data)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type markdownPayload struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func buildSendPayload(msg channel.Message, targetID string) (any, string, string, error) {
|
||||
if strings.TrimSpace(targetID) == "" {
|
||||
return nil, "", "", fmt.Errorf("wecom target id is required")
|
||||
}
|
||||
reqID := NewReqID(WSCmdSendMessage)
|
||||
if card, ok := readTemplateCard(msg.Metadata); ok {
|
||||
return SendMessageTemplateCardBody{
|
||||
ChatID: targetID,
|
||||
MsgType: "template_card",
|
||||
TemplateCard: card,
|
||||
}, WSCmdSendMessage, reqID, nil
|
||||
}
|
||||
|
||||
// aibot_send_msg currently supports markdown/template_card in official SDK.
|
||||
// Attachments should be sent through callback-reply path (aibot_respond_msg).
|
||||
if len(msg.Attachments) > 0 {
|
||||
return nil, "", "", fmt.Errorf("wecom proactive send does not support attachments; use reply flow")
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(msg.PlainText())
|
||||
if text == "" {
|
||||
return nil, "", "", fmt.Errorf("wecom outbound text is required")
|
||||
}
|
||||
return SendMessageMarkdownBody{
|
||||
ChatID: targetID,
|
||||
MsgType: "markdown",
|
||||
Markdown: markdownPayload{
|
||||
Content: text,
|
||||
},
|
||||
}, WSCmdSendMessage, reqID, nil
|
||||
}
|
||||
|
||||
func buildRespondPayload(msg channel.Message, replyReqID string) (any, string, string, error) {
|
||||
return buildRespondPayloadWithStream(msg, replyReqID, "", true)
|
||||
}
|
||||
|
||||
func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, streamID string, finish bool) (any, string, string, error) {
|
||||
reqID := strings.TrimSpace(replyReqID)
|
||||
if reqID == "" {
|
||||
return nil, "", "", fmt.Errorf("reply req_id is required")
|
||||
}
|
||||
if finish {
|
||||
if body, ok := buildWelcomePayload(msg); ok {
|
||||
return body, WSCmdRespondWelcome, reqID, nil
|
||||
}
|
||||
if body, ok := readUpdateTemplateCard(msg.Metadata); ok {
|
||||
return body, WSCmdRespondUpdate, reqID, nil
|
||||
}
|
||||
}
|
||||
text := strings.TrimSpace(msg.PlainText())
|
||||
if finish && text == "" && len(msg.Attachments) == 0 {
|
||||
return nil, "", "", fmt.Errorf("wecom reply payload is empty")
|
||||
}
|
||||
if !finish && text == "" {
|
||||
return nil, "", "", fmt.Errorf("wecom stream delta content is empty")
|
||||
}
|
||||
streamID = strings.TrimSpace(streamID)
|
||||
if streamID == "" {
|
||||
streamID = NewReqID("stream")
|
||||
}
|
||||
stream := StreamReplyBlock{
|
||||
ID: streamID,
|
||||
Finish: finish,
|
||||
Content: text,
|
||||
}
|
||||
if feedbackID := readFeedbackID(msg.Metadata); feedbackID != "" {
|
||||
stream.Feedback = &StreamReplyFeedback{ID: feedbackID}
|
||||
}
|
||||
if finish && len(msg.Attachments) > 0 {
|
||||
first := msg.Attachments[0]
|
||||
base64Data := extractBase64Content(first.Base64)
|
||||
if base64Data != "" {
|
||||
raw, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err == nil && len(raw) > 0 {
|
||||
stream.MsgItems = []StreamReplyItem{
|
||||
{
|
||||
MsgType: "image",
|
||||
Image: &StreamReplyImage{
|
||||
Base64: base64.StdEncoding.EncodeToString(raw),
|
||||
MD5: fmt.Sprintf("%x", md5.Sum(raw)),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if card, ok := readTemplateCard(msg.Metadata); ok {
|
||||
return StreamWithTemplateCardReplyBody{
|
||||
MsgType: "stream_with_template_card",
|
||||
Stream: stream,
|
||||
TemplateCard: card,
|
||||
}, WSCmdRespond, reqID, nil
|
||||
}
|
||||
return StreamReplyBody{
|
||||
MsgType: "stream",
|
||||
Stream: stream,
|
||||
}, WSCmdRespond, reqID, nil
|
||||
}
|
||||
|
||||
func buildWelcomePayload(msg channel.Message) (any, bool) {
|
||||
if !readBool(msg.Metadata, "wecom_welcome") {
|
||||
return nil, false
|
||||
}
|
||||
if card, ok := readTemplateCard(msg.Metadata); ok {
|
||||
return WelcomeTemplateCardReplyBody{
|
||||
MsgType: "template_card",
|
||||
TemplateCard: card,
|
||||
}, true
|
||||
}
|
||||
text := strings.TrimSpace(msg.PlainText())
|
||||
if text == "" {
|
||||
return nil, false
|
||||
}
|
||||
return WelcomeTextReplyBody{
|
||||
MsgType: "text",
|
||||
Text: welcomeTextBody{
|
||||
Content: text,
|
||||
},
|
||||
}, true
|
||||
}
|
||||
|
||||
func readTemplateCard(metadata map[string]any) (map[string]any, bool) {
|
||||
if metadata == nil {
|
||||
return nil, false
|
||||
}
|
||||
raw, ok := metadata["wecom_template_card"]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
card, ok := raw.(map[string]any)
|
||||
if !ok || len(card) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return card, true
|
||||
}
|
||||
|
||||
func readUpdateTemplateCard(metadata map[string]any) (UpdateTemplateCardBody, bool) {
|
||||
if metadata == nil {
|
||||
return UpdateTemplateCardBody{}, false
|
||||
}
|
||||
raw, ok := metadata["wecom_update_template_card"]
|
||||
if !ok {
|
||||
return UpdateTemplateCardBody{}, false
|
||||
}
|
||||
card, ok := raw.(map[string]any)
|
||||
if !ok || len(card) == 0 {
|
||||
return UpdateTemplateCardBody{}, false
|
||||
}
|
||||
body := UpdateTemplateCardBody{
|
||||
ResponseType: "update_template_card",
|
||||
TemplateCard: card,
|
||||
}
|
||||
if userIDs := readStringSlice(metadata["wecom_update_userids"]); len(userIDs) > 0 {
|
||||
body.UserIDs = userIDs
|
||||
}
|
||||
return body, true
|
||||
}
|
||||
|
||||
func readStringSlice(raw any) []string {
|
||||
switch v := raw.(type) {
|
||||
case []string:
|
||||
out := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s := strings.TrimSpace(item); s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok && strings.TrimSpace(s) != "" {
|
||||
out = append(out, strings.TrimSpace(s))
|
||||
}
|
||||
}
|
||||
return out
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(v, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if s := strings.TrimSpace(part); s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func readFeedbackID(metadata map[string]any) string {
|
||||
if metadata == nil {
|
||||
return ""
|
||||
}
|
||||
raw, ok := metadata["wecom_feedback_id"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if v, ok := raw.(string); ok {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func readBool(metadata map[string]any, key string) bool {
|
||||
if metadata == nil || strings.TrimSpace(key) == "" {
|
||||
return false
|
||||
}
|
||||
raw, ok := metadata[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
v, ok := raw.(bool)
|
||||
return ok && v
|
||||
}
|
||||
|
||||
func extractBase64Content(v string) string {
|
||||
value := strings.TrimSpace(v)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
if idx := strings.Index(value, ","); idx > 0 && strings.Contains(strings.ToLower(value[:idx]), "base64") {
|
||||
return strings.TrimSpace(value[idx+1:])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) getClient(botID string) *WSClient {
|
||||
key := strings.TrimSpace(botID)
|
||||
if key == "" {
|
||||
return nil
|
||||
}
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return a.clients[key]
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) lookupCallbackContext(reply *channel.ReplyRef) (callbackContext, bool) {
|
||||
if a == nil || a.cache == nil || reply == nil {
|
||||
return callbackContext{}, false
|
||||
}
|
||||
messageID := strings.TrimSpace(reply.MessageID)
|
||||
if messageID == "" {
|
||||
return callbackContext{}, false
|
||||
}
|
||||
return a.cache.Get(messageID)
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) ensureHTTPClient() {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.http == nil {
|
||||
a.http = NewHTTPClient(HTTPClientOptions{Logger: a.logger})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
func TestBuildSendPayload_Text(t *testing.T) {
|
||||
payload, cmd, reqID, err := buildSendPayload(channel.Message{
|
||||
Format: channel.MessageFormatPlain,
|
||||
Text: "hello",
|
||||
}, "chat_1")
|
||||
if err != nil {
|
||||
t.Fatalf("buildSendPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdSendMessage {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
if reqID == "" {
|
||||
t.Fatal("reqID should not be empty")
|
||||
}
|
||||
p, ok := payload.(SendMessageMarkdownBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.Markdown.Content != "hello" {
|
||||
t.Fatalf("unexpected payload content: %q", p.Markdown.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSendPayload_AttachmentNotSupported(t *testing.T) {
|
||||
_, _, _, err := buildSendPayload(channel.Message{
|
||||
Attachments: []channel.Attachment{
|
||||
{Type: channel.AttachmentImage, Base64: "aGVsbG8="},
|
||||
},
|
||||
}, "chat_1")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for proactive attachment send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayload_Stream(t *testing.T) {
|
||||
payload, cmd, reqID, err := buildRespondPayload(channel.Message{
|
||||
Format: channel.MessageFormatMarkdown,
|
||||
Text: "world",
|
||||
}, "req_abc")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRespondPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdRespond {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
if reqID != "req_abc" {
|
||||
t.Fatalf("unexpected req id: %q", reqID)
|
||||
}
|
||||
p, ok := payload.(StreamReplyBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.MsgType != "stream" || p.Stream.Content != "world" || !p.Stream.Finish {
|
||||
t.Fatalf("unexpected stream payload: %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayload_StreamWithFeedback(t *testing.T) {
|
||||
payload, cmd, _, err := buildRespondPayload(channel.Message{
|
||||
Text: "world",
|
||||
Metadata: map[string]any{
|
||||
"wecom_feedback_id": "feedback_1",
|
||||
},
|
||||
}, "req_abc")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRespondPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdRespond {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
p, ok := payload.(StreamReplyBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.Stream.Feedback == nil || p.Stream.Feedback.ID != "feedback_1" {
|
||||
t.Fatalf("unexpected feedback: %+v", p.Stream.Feedback)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayloadWithStream_Delta(t *testing.T) {
|
||||
payload, cmd, reqID, err := buildRespondPayloadWithStream(channel.Message{
|
||||
Text: "delta",
|
||||
}, "req_abc", "stream_1", false)
|
||||
if err != nil {
|
||||
t.Fatalf("buildRespondPayloadWithStream error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdRespond || reqID != "req_abc" {
|
||||
t.Fatalf("unexpected cmd/reqid: %q %q", cmd, reqID)
|
||||
}
|
||||
p, ok := payload.(StreamReplyBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.Stream.ID != "stream_1" || p.Stream.Finish {
|
||||
t.Fatalf("unexpected stream payload: %+v", p.Stream)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayloadWithStream_EmptyDeltaRejected(t *testing.T) {
|
||||
_, _, _, err := buildRespondPayloadWithStream(channel.Message{}, "req_abc", "stream_1", false)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty delta content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSendPayload_TemplateCard(t *testing.T) {
|
||||
payload, cmd, _, err := buildSendPayload(channel.Message{
|
||||
Metadata: map[string]any{
|
||||
"wecom_template_card": map[string]any{
|
||||
"card_type": "text_notice",
|
||||
},
|
||||
},
|
||||
}, "chat_1")
|
||||
if err != nil {
|
||||
t.Fatalf("buildSendPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdSendMessage {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
p, ok := payload.(SendMessageTemplateCardBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.MsgType != "template_card" {
|
||||
t.Fatalf("unexpected msg type: %q", p.MsgType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayload_StreamWithTemplateCard(t *testing.T) {
|
||||
payload, cmd, _, err := buildRespondPayload(channel.Message{
|
||||
Text: "x",
|
||||
Metadata: map[string]any{
|
||||
"wecom_template_card": map[string]any{
|
||||
"card_type": "text_notice",
|
||||
},
|
||||
},
|
||||
}, "req_abc")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRespondPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdRespond {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
p, ok := payload.(StreamWithTemplateCardReplyBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.MsgType != "stream_with_template_card" {
|
||||
t.Fatalf("unexpected msg type: %q", p.MsgType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayload_UpdateTemplateCard(t *testing.T) {
|
||||
payload, cmd, reqID, err := buildRespondPayload(channel.Message{
|
||||
Metadata: map[string]any{
|
||||
"wecom_update_template_card": map[string]any{
|
||||
"card_type": "text_notice",
|
||||
"task_id": "task-1",
|
||||
},
|
||||
"wecom_update_userids": []any{"u1", "u2"},
|
||||
},
|
||||
}, "req_abc")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRespondPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdRespondUpdate {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
if reqID != "req_abc" {
|
||||
t.Fatalf("unexpected req id: %q", reqID)
|
||||
}
|
||||
p, ok := payload.(UpdateTemplateCardBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.ResponseType != "update_template_card" {
|
||||
t.Fatalf("unexpected response type: %q", p.ResponseType)
|
||||
}
|
||||
if len(p.UserIDs) != 2 || p.UserIDs[0] != "u1" || p.UserIDs[1] != "u2" {
|
||||
t.Fatalf("unexpected userids: %#v", p.UserIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayload_WelcomeText(t *testing.T) {
|
||||
payload, cmd, reqID, err := buildRespondPayload(channel.Message{
|
||||
Text: "welcome",
|
||||
Metadata: map[string]any{
|
||||
"wecom_welcome": true,
|
||||
},
|
||||
}, "req_abc")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRespondPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdRespondWelcome {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
if reqID != "req_abc" {
|
||||
t.Fatalf("unexpected req id: %q", reqID)
|
||||
}
|
||||
p, ok := payload.(WelcomeTextReplyBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.MsgType != "text" || p.Text.Content != "welcome" {
|
||||
t.Fatalf("unexpected welcome payload: %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRespondPayload_WelcomeTemplateCard(t *testing.T) {
|
||||
payload, cmd, _, err := buildRespondPayload(channel.Message{
|
||||
Metadata: map[string]any{
|
||||
"wecom_welcome": true,
|
||||
"wecom_template_card": map[string]any{
|
||||
"card_type": "text_notice",
|
||||
},
|
||||
},
|
||||
}, "req_abc")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRespondPayload error = %v", err)
|
||||
}
|
||||
if cmd != WSCmdRespondWelcome {
|
||||
t.Fatalf("unexpected cmd: %q", cmd)
|
||||
}
|
||||
p, ok := payload.(WelcomeTemplateCardReplyBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected payload type: %T", payload)
|
||||
}
|
||||
if p.MsgType != "template_card" {
|
||||
t.Fatalf("unexpected msg type: %q", p.MsgType)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultWSURL = "wss://openws.work.weixin.qq.com"
|
||||
)
|
||||
|
||||
const (
|
||||
WSCmdSubscribe = "aibot_subscribe"
|
||||
WSCmdHeartbeat = "ping"
|
||||
WSCmdRespond = "aibot_respond_msg"
|
||||
WSCmdRespondWelcome = "aibot_respond_welcome_msg"
|
||||
WSCmdRespondUpdate = "aibot_respond_update_msg"
|
||||
WSCmdSendMessage = "aibot_send_msg"
|
||||
WSCmdMsgCallback = "aibot_msg_callback"
|
||||
WSCmdEventCallback = "aibot_event_callback"
|
||||
)
|
||||
|
||||
type WSHeaders struct {
|
||||
ReqID string `json:"req_id"`
|
||||
}
|
||||
|
||||
type WSFrame struct {
|
||||
Cmd string `json:"cmd,omitempty"`
|
||||
Headers WSHeaders `json:"headers"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
ErrCode int `json:"errcode,omitempty"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
func (f WSFrame) DecodeBody(dst any) error {
|
||||
if len(f.Body) == 0 {
|
||||
return fmt.Errorf("wecom frame body is empty")
|
||||
}
|
||||
if dst == nil {
|
||||
return fmt.Errorf("decode target is nil")
|
||||
}
|
||||
return json.Unmarshal(f.Body, dst)
|
||||
}
|
||||
|
||||
func BuildFrame(cmd, reqID string, body any) (WSFrame, error) {
|
||||
frame := WSFrame{
|
||||
Cmd: cmd,
|
||||
Headers: WSHeaders{
|
||||
ReqID: strings.TrimSpace(reqID),
|
||||
},
|
||||
}
|
||||
if frame.Headers.ReqID == "" {
|
||||
return WSFrame{}, fmt.Errorf("req_id is required")
|
||||
}
|
||||
if body == nil {
|
||||
return frame, nil
|
||||
}
|
||||
raw, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return WSFrame{}, err
|
||||
}
|
||||
frame.Body = raw
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func NewReqID(prefix string) string {
|
||||
p := strings.TrimSpace(prefix)
|
||||
if p == "" {
|
||||
return uuid.NewString()
|
||||
}
|
||||
return p + "_" + uuid.NewString()
|
||||
}
|
||||
|
||||
type AuthCredentials struct {
|
||||
BotID string
|
||||
Secret string
|
||||
}
|
||||
|
||||
func (c AuthCredentials) Validate() error {
|
||||
if strings.TrimSpace(c.BotID) == "" {
|
||||
return fmt.Errorf("wecom bot_id is required")
|
||||
}
|
||||
if strings.TrimSpace(c.Secret) == "" {
|
||||
return fmt.Errorf("wecom secret is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SubscribeBody struct {
|
||||
BotID string `json:"bot_id"`
|
||||
Secret string `json:"secret"`
|
||||
}
|
||||
|
||||
type CallbackFrom struct {
|
||||
UserID string `json:"userid,omitempty"`
|
||||
}
|
||||
|
||||
type MessageText struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type MessageImage struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
}
|
||||
|
||||
type MessageFile struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
}
|
||||
|
||||
type MessageVoice struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type MessageVideo struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
}
|
||||
|
||||
type MessageMixedItem struct {
|
||||
MsgType string `json:"msgtype,omitempty"`
|
||||
Text *MessageText `json:"text,omitempty"`
|
||||
Markdown *MessageText `json:"markdown,omitempty"`
|
||||
Image *MessageImage `json:"image,omitempty"`
|
||||
File *MessageFile `json:"file,omitempty"`
|
||||
Voice *MessageVoice `json:"voice,omitempty"`
|
||||
Video *MessageVideo `json:"video,omitempty"`
|
||||
}
|
||||
|
||||
type MessageQuote struct {
|
||||
MsgID string `json:"msgid,omitempty"`
|
||||
}
|
||||
|
||||
type MessageCallbackBody struct {
|
||||
MsgID string `json:"msgid,omitempty"`
|
||||
AIBotID string `json:"aibotid,omitempty"`
|
||||
ChatID string `json:"chatid,omitempty"`
|
||||
ChatType string `json:"chattype,omitempty"`
|
||||
From CallbackFrom `json:"from,omitempty"`
|
||||
CreateTime int64 `json:"create_time,omitempty"`
|
||||
MsgType string `json:"msgtype,omitempty"`
|
||||
ResponseURL string `json:"response_url,omitempty"`
|
||||
Text *MessageText `json:"text,omitempty"`
|
||||
Markdown *MessageText `json:"markdown,omitempty"`
|
||||
Image *MessageImage `json:"image,omitempty"`
|
||||
File *MessageFile `json:"file,omitempty"`
|
||||
Voice *MessageVoice `json:"voice,omitempty"`
|
||||
Video *MessageVideo `json:"video,omitempty"`
|
||||
Mixed []MessageMixedItem `json:"mixed,omitempty"`
|
||||
Quote *MessageQuote `json:"quote,omitempty"`
|
||||
}
|
||||
|
||||
type EventPayload struct {
|
||||
EventType string `json:"event_type,omitempty"`
|
||||
EventType2 string `json:"eventtype,omitempty"`
|
||||
EventKey string `json:"event_key,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
type EventTask struct {
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
}
|
||||
|
||||
type EventCallbackBody struct {
|
||||
MsgID string `json:"msgid,omitempty"`
|
||||
AIBotID string `json:"aibotid,omitempty"`
|
||||
ChatID string `json:"chatid,omitempty"`
|
||||
ChatType string `json:"chattype,omitempty"`
|
||||
From CallbackFrom `json:"from,omitempty"`
|
||||
CreateTime int64 `json:"create_time,omitempty"`
|
||||
MsgType string `json:"msgtype,omitempty"`
|
||||
ResponseURL string `json:"response_url,omitempty"`
|
||||
Event EventPayload `json:"event,omitempty"`
|
||||
Task EventTask `json:"task,omitempty"`
|
||||
}
|
||||
|
||||
type StreamReplyBody struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Stream StreamReplyBlock `json:"stream"`
|
||||
}
|
||||
|
||||
type StreamReplyBlock struct {
|
||||
ID string `json:"id"`
|
||||
Finish bool `json:"finish,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
MsgItems []StreamReplyItem `json:"msg_item,omitempty"`
|
||||
Feedback *StreamReplyFeedback `json:"feedback,omitempty"`
|
||||
}
|
||||
|
||||
type StreamReplyItem struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Image *StreamReplyImage `json:"image,omitempty"`
|
||||
}
|
||||
|
||||
type StreamReplyImage struct {
|
||||
Base64 string `json:"base64"`
|
||||
MD5 string `json:"md5"`
|
||||
}
|
||||
|
||||
type StreamReplyFeedback struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type SendMessageMarkdownBody struct {
|
||||
ChatID string `json:"chatid"`
|
||||
MsgType string `json:"msgtype"`
|
||||
Markdown markdownPayload `json:"markdown"`
|
||||
}
|
||||
|
||||
type SendMessageTemplateCardBody struct {
|
||||
ChatID string `json:"chatid"`
|
||||
MsgType string `json:"msgtype"`
|
||||
TemplateCard map[string]any `json:"template_card"`
|
||||
}
|
||||
|
||||
type StreamWithTemplateCardReplyBody struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Stream StreamReplyBlock `json:"stream"`
|
||||
TemplateCard map[string]any `json:"template_card"`
|
||||
}
|
||||
|
||||
type WelcomeTextReplyBody struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text welcomeTextBody `json:"text"`
|
||||
}
|
||||
|
||||
type welcomeTextBody struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type WelcomeTemplateCardReplyBody struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
TemplateCard map[string]any `json:"template_card"`
|
||||
}
|
||||
|
||||
type UpdateTemplateCardBody struct {
|
||||
ResponseType string `json:"response_type"`
|
||||
UserIDs []string `json:"userids,omitempty"`
|
||||
TemplateCard map[string]any `json:"template_card"`
|
||||
}
|
||||
|
||||
func unixMilliseconds(ts int64) time.Time {
|
||||
if ts <= 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.UnixMilli(ts)
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package wecom
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildFrameAndDecodeBody(t *testing.T) {
|
||||
type body struct {
|
||||
BotID string `json:"bot_id"`
|
||||
}
|
||||
frame, err := BuildFrame(WSCmdSubscribe, "req-1", body{BotID: "bot123"})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildFrame error = %v", err)
|
||||
}
|
||||
var decoded body
|
||||
if err := frame.DecodeBody(&decoded); err != nil {
|
||||
t.Fatalf("DecodeBody error = %v", err)
|
||||
}
|
||||
if decoded.BotID != "bot123" {
|
||||
t.Fatalf("unexpected decoded body: %+v", decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthCredentialsValidate(t *testing.T) {
|
||||
if err := (AuthCredentials{}).Validate(); err == nil {
|
||||
t.Fatal("expected validation error for empty credentials")
|
||||
}
|
||||
if err := (AuthCredentials{BotID: "id", Secret: "sec"}).Validate(); err != nil {
|
||||
t.Fatalf("unexpected validation error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
const Type channel.ChannelType = "wecom"
|
||||
|
||||
type wsClientFactory func(opts WSClientOptions) *WSClient
|
||||
|
||||
type WeComAdapter struct {
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
clients map[string]*WSClient
|
||||
http *HTTPClient
|
||||
cache *callbackContextCache
|
||||
|
||||
newWSClient wsClientFactory
|
||||
}
|
||||
|
||||
func NewWeComAdapter(log *slog.Logger) *WeComAdapter {
|
||||
if log == nil {
|
||||
log = slog.Default()
|
||||
}
|
||||
return &WeComAdapter{
|
||||
logger: log.With(slog.String("adapter", "wecom")),
|
||||
clients: make(map[string]*WSClient),
|
||||
http: NewHTTPClient(HTTPClientOptions{Logger: log}),
|
||||
cache: newCallbackContextCache(24 * time.Hour),
|
||||
newWSClient: func(opts WSClientOptions) *WSClient { return NewWSClient(opts) },
|
||||
}
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) Type() channel.ChannelType { return Type }
|
||||
|
||||
func (a *WeComAdapter) Descriptor() channel.Descriptor {
|
||||
return channel.Descriptor{
|
||||
Type: Type,
|
||||
DisplayName: "WeCom",
|
||||
Capabilities: channel.ChannelCapabilities{
|
||||
Text: true,
|
||||
Markdown: true,
|
||||
Attachments: true,
|
||||
Media: true,
|
||||
Reply: true,
|
||||
Streaming: true,
|
||||
BlockStreaming: true,
|
||||
ChatTypes: []string{"private", "group"},
|
||||
},
|
||||
ConfigSchema: channel.ConfigSchema{
|
||||
Version: 1,
|
||||
Fields: map[string]channel.FieldSchema{
|
||||
"botId": {Type: channel.FieldString, Required: true, Title: "Bot ID"},
|
||||
"secret": {Type: channel.FieldSecret, Required: true, Title: "Secret"},
|
||||
"wsUrl": {Type: channel.FieldString, Title: "WebSocket URL", Example: defaultWSURL},
|
||||
"heartbeatSeconds": {Type: channel.FieldNumber, Title: "Heartbeat Seconds"},
|
||||
"ackTimeoutSeconds": {Type: channel.FieldNumber, Title: "Ack Timeout Seconds"},
|
||||
"writeTimeoutSeconds": {Type: channel.FieldNumber, Title: "Write Timeout Seconds"},
|
||||
"readTimeoutSeconds": {Type: channel.FieldNumber, Title: "Read Timeout Seconds"},
|
||||
},
|
||||
},
|
||||
UserConfigSchema: channel.ConfigSchema{
|
||||
Version: 1,
|
||||
Fields: map[string]channel.FieldSchema{
|
||||
"chat_id": {Type: channel.FieldString},
|
||||
"user_id": {Type: channel.FieldString},
|
||||
},
|
||||
},
|
||||
TargetSpec: channel.TargetSpec{
|
||||
Format: "chat_id:xxx | user_id:xxx",
|
||||
Hints: []channel.TargetHint{
|
||||
{Label: "Chat ID", Example: "chat_id:wrk_abc"},
|
||||
{Label: "User ID", Example: "user_id:zhangsan"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) {
|
||||
return normalizeConfig(raw)
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) {
|
||||
return normalizeUserConfig(raw)
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) }
|
||||
|
||||
func (a *WeComAdapter) ResolveTarget(userConfig map[string]any) (string, error) {
|
||||
return resolveTarget(userConfig)
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool {
|
||||
return matchBinding(config, criteria)
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) BuildUserConfig(identity channel.Identity) map[string]any {
|
||||
return buildUserConfig(identity)
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) {
|
||||
_ = ctx
|
||||
cfg, err := parseConfig(credentials)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
externalID := strings.TrimSpace(cfg.BotID)
|
||||
identity := map[string]any{
|
||||
"bot_id": externalID,
|
||||
"aibot_id": externalID,
|
||||
}
|
||||
return identity, externalID, nil
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) {
|
||||
parsed, err := parseConfig(cfg.Credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := a.newWSClient(WSClientOptions{
|
||||
URL: parsed.WSURL,
|
||||
Logger: a.logger,
|
||||
HeartbeatInterval: time.Duration(secondsOrDefault(parsed.HeartbeatSeconds, 30)) * time.Second,
|
||||
AckTimeout: time.Duration(secondsOrDefault(parsed.AckTimeoutSeconds, 8)) * time.Second,
|
||||
WriteTimeout: time.Duration(secondsOrDefault(parsed.WriteTimeoutSeconds, 8)) * time.Second,
|
||||
ReadTimeout: time.Duration(secondsOrDefault(parsed.ReadTimeoutSeconds, 70)) * time.Second,
|
||||
ReconnectBaseDelay: 1 * time.Second,
|
||||
ReconnectMaxDelay: 30 * time.Second,
|
||||
})
|
||||
|
||||
key := strings.TrimSpace(parsed.BotID)
|
||||
a.mu.Lock()
|
||||
a.clients[key] = client
|
||||
a.mu.Unlock()
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err := client.Run(connCtx, AuthCredentials{
|
||||
BotID: parsed.BotID,
|
||||
Secret: parsed.Secret,
|
||||
}, func(frameCtx context.Context, frame WSFrame) error {
|
||||
return a.handleFrame(frameCtx, cfg, frame, handler)
|
||||
})
|
||||
if err != nil && connCtx.Err() == nil {
|
||||
a.logger.Error("wecom websocket stopped",
|
||||
slog.String("config_id", cfg.ID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
stop := func(context.Context) error {
|
||||
cancel()
|
||||
_ = client.Close()
|
||||
<-done
|
||||
a.mu.Lock()
|
||||
if current, ok := a.clients[key]; ok && current == client {
|
||||
delete(a.clients, key)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return channel.NewConnection(cfg, stop), nil
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error {
|
||||
targetKind, targetID, ok := parseTarget(msg.Target)
|
||||
if !ok {
|
||||
return fmt.Errorf("wecom target is required")
|
||||
}
|
||||
parsed, err := parseConfig(cfg.Credentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client := a.getClient(parsed.BotID)
|
||||
if client == nil {
|
||||
return fmt.Errorf("wecom connection is not active")
|
||||
}
|
||||
if msg.Message.IsEmpty() {
|
||||
return fmt.Errorf("message is required")
|
||||
}
|
||||
var (
|
||||
payload any
|
||||
cmd string
|
||||
reqID string
|
||||
buildErr error
|
||||
)
|
||||
if ctxMeta, ok := a.lookupCallbackContext(msg.Message.Reply); ok {
|
||||
payload, cmd, reqID, buildErr = buildRespondPayload(msg.Message, ctxMeta.ReqID)
|
||||
} else {
|
||||
_ = targetKind
|
||||
payload, cmd, reqID, buildErr = buildSendPayload(msg.Message, targetID)
|
||||
}
|
||||
if buildErr != nil {
|
||||
return buildErr
|
||||
}
|
||||
ack, err := client.Reply(ctx, reqID, cmd, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ack.ErrCode != 0 {
|
||||
return fmt.Errorf("wecom send failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) {
|
||||
target = strings.TrimSpace(target)
|
||||
if target == "" {
|
||||
return nil, fmt.Errorf("wecom target is required")
|
||||
}
|
||||
reply := opts.Reply
|
||||
if reply == nil && strings.TrimSpace(opts.SourceMessageID) != "" {
|
||||
reply = &channel.ReplyRef{
|
||||
Target: target,
|
||||
MessageID: strings.TrimSpace(opts.SourceMessageID),
|
||||
}
|
||||
}
|
||||
return &wecomOutboundStream{
|
||||
adapter: a,
|
||||
cfg: cfg,
|
||||
target: target,
|
||||
reply: reply,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type wecomOutboundStream struct {
|
||||
adapter *WeComAdapter
|
||||
cfg channel.ChannelConfig
|
||||
target string
|
||||
reply *channel.ReplyRef
|
||||
|
||||
mu sync.Mutex
|
||||
closed atomic.Bool
|
||||
finalSent atomic.Bool
|
||||
textBuilder strings.Builder
|
||||
attachments []channel.Attachment
|
||||
final *channel.Message
|
||||
streamID string
|
||||
lastPreview string
|
||||
}
|
||||
|
||||
func (s *wecomOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error {
|
||||
if s.adapter == nil {
|
||||
return errors.New("wecom stream not configured")
|
||||
}
|
||||
if s.closed.Load() {
|
||||
return errors.New("wecom stream is closed")
|
||||
}
|
||||
if s.finalSent.Load() {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
switch event.Type {
|
||||
case channel.StreamEventStatus,
|
||||
channel.StreamEventPhaseStart,
|
||||
channel.StreamEventPhaseEnd,
|
||||
channel.StreamEventToolCallStart,
|
||||
channel.StreamEventToolCallEnd,
|
||||
channel.StreamEventAgentStart,
|
||||
channel.StreamEventAgentEnd,
|
||||
channel.StreamEventProcessingStarted,
|
||||
channel.StreamEventProcessingCompleted,
|
||||
channel.StreamEventProcessingFailed:
|
||||
return nil
|
||||
case channel.StreamEventDelta:
|
||||
if strings.TrimSpace(event.Delta) == "" || event.Phase == channel.StreamPhaseReasoning {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.textBuilder.WriteString(event.Delta)
|
||||
s.mu.Unlock()
|
||||
return s.pushPreview(ctx)
|
||||
case channel.StreamEventAttachment:
|
||||
if len(event.Attachments) == 0 {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.attachments = append(s.attachments, event.Attachments...)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
case channel.StreamEventFinal:
|
||||
if event.Final == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
final := event.Final.Message
|
||||
s.final = &final
|
||||
s.mu.Unlock()
|
||||
return s.flush(ctx)
|
||||
case channel.StreamEventError:
|
||||
text := strings.TrimSpace(event.Error)
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.final = &channel.Message{Format: channel.MessageFormatPlain, Text: "Error: " + text}
|
||||
s.mu.Unlock()
|
||||
return s.flush(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wecomOutboundStream) Close(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
s.closed.Store(true)
|
||||
if s.finalSent.Load() {
|
||||
return nil
|
||||
}
|
||||
return s.flush(ctx)
|
||||
}
|
||||
|
||||
func (s *wecomOutboundStream) flush(ctx context.Context) error {
|
||||
if s.finalSent.Load() {
|
||||
return nil
|
||||
}
|
||||
msg, streamID := s.snapshotMessage(true)
|
||||
if msg.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
if ctxMeta, ok := s.adapter.lookupCallbackContext(msg.Reply); ok {
|
||||
if err := s.adapter.sendRespondStream(ctx, s.cfg, msg, ctxMeta.ReqID, streamID, true); err != nil {
|
||||
return err
|
||||
}
|
||||
s.finalSent.Store(true)
|
||||
return nil
|
||||
}
|
||||
if err := s.adapter.Send(ctx, s.cfg, channel.OutboundMessage{
|
||||
Target: s.target,
|
||||
Message: msg,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
s.finalSent.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wecomOutboundStream) pushPreview(ctx context.Context) error {
|
||||
if s.finalSent.Load() {
|
||||
return nil
|
||||
}
|
||||
msg, streamID := s.snapshotMessage(false)
|
||||
text := strings.TrimSpace(msg.PlainText())
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.lastPreview == text {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
if ctxMeta, ok := s.adapter.lookupCallbackContext(msg.Reply); ok {
|
||||
if err := s.adapter.sendRespondStream(ctx, s.cfg, msg, ctxMeta.ReqID, streamID, false); err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.lastPreview = text
|
||||
s.mu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wecomOutboundStream) snapshotMessage(includeAttachments bool) (channel.Message, string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
msg := channel.Message{}
|
||||
if s.final != nil {
|
||||
msg = *s.final
|
||||
}
|
||||
if strings.TrimSpace(msg.Text) == "" {
|
||||
msg.Text = strings.TrimSpace(s.textBuilder.String())
|
||||
}
|
||||
if includeAttachments && len(msg.Attachments) == 0 && len(s.attachments) > 0 {
|
||||
msg.Attachments = append(msg.Attachments, s.attachments...)
|
||||
}
|
||||
if msg.Reply == nil && s.reply != nil {
|
||||
msg.Reply = s.reply
|
||||
}
|
||||
if s.streamID == "" {
|
||||
s.streamID = NewReqID("stream")
|
||||
}
|
||||
return msg, s.streamID
|
||||
}
|
||||
|
||||
func (a *WeComAdapter) sendRespondStream(ctx context.Context, cfg channel.ChannelConfig, msg channel.Message, reqID string, streamID string, finish bool) error {
|
||||
parsed, err := parseConfig(cfg.Credentials)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client := a.getClient(parsed.BotID)
|
||||
if client == nil {
|
||||
return fmt.Errorf("wecom connection is not active")
|
||||
}
|
||||
payload, cmd, ackReqID, err := buildRespondPayloadWithStream(msg, reqID, streamID, finish)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ack, err := client.Reply(ctx, ackReqID, cmd, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ack.ErrCode != 0 {
|
||||
return fmt.Errorf("wecom send failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func secondsOrDefault(value int, fallback int) int {
|
||||
if value > 0 {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
func TestDiscoverSelf(t *testing.T) {
|
||||
adapter := NewWeComAdapter(nil)
|
||||
identity, externalID, err := adapter.DiscoverSelf(context.Background(), map[string]any{
|
||||
"botId": "bot_123",
|
||||
"secret": "sec",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("DiscoverSelf error = %v", err)
|
||||
}
|
||||
if externalID != "bot_123" {
|
||||
t.Fatalf("unexpected external id: %q", externalID)
|
||||
}
|
||||
if identity["bot_id"] != "bot_123" {
|
||||
t.Fatalf("unexpected bot_id: %v", identity["bot_id"])
|
||||
}
|
||||
if identity["aibot_id"] != "bot_123" {
|
||||
t.Fatalf("unexpected aibot_id: %v", identity["aibot_id"])
|
||||
}
|
||||
if _, ok := identity["name"]; ok {
|
||||
t.Fatalf("unexpected name field: %v", identity["name"])
|
||||
}
|
||||
if _, ok := identity["display_name"]; ok {
|
||||
t.Fatalf("unexpected display_name field: %v", identity["display_name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenStream_FallbackReplyFromSourceMessageID(t *testing.T) {
|
||||
adapter := NewWeComAdapter(nil)
|
||||
stream, err := adapter.OpenStream(context.Background(), channel.ChannelConfig{}, "chat_id:chat_1", channel.StreamOptions{
|
||||
SourceMessageID: "msg_1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStream error = %v", err)
|
||||
}
|
||||
ws, ok := stream.(*wecomOutboundStream)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected stream type: %T", stream)
|
||||
}
|
||||
if ws.reply == nil || ws.reply.MessageID != "msg_1" || ws.reply.Target != "chat_id:chat_1" {
|
||||
t.Fatalf("unexpected reply fallback: %+v", ws.reply)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,396 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WSClientOptions struct {
|
||||
URL string
|
||||
Dialer *websocket.Dialer
|
||||
Logger *slog.Logger
|
||||
AckTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
HeartbeatInterval time.Duration
|
||||
ReconnectBaseDelay time.Duration
|
||||
ReconnectMaxDelay time.Duration
|
||||
MaxReconnectAttempts int
|
||||
}
|
||||
|
||||
type WSClient struct {
|
||||
opts WSClientOptions
|
||||
logger *slog.Logger
|
||||
|
||||
writeMu sync.Mutex
|
||||
waitMu sync.Mutex
|
||||
connMu sync.RWMutex
|
||||
|
||||
conn *websocket.Conn
|
||||
waiters map[string]chan wsAck
|
||||
closed bool
|
||||
}
|
||||
|
||||
type wsAck struct {
|
||||
frame WSFrame
|
||||
err error
|
||||
}
|
||||
|
||||
func NewWSClient(opts WSClientOptions) *WSClient {
|
||||
if strings.TrimSpace(opts.URL) == "" {
|
||||
opts.URL = defaultWSURL
|
||||
}
|
||||
if opts.Logger == nil {
|
||||
opts.Logger = slog.Default()
|
||||
}
|
||||
if opts.AckTimeout <= 0 {
|
||||
opts.AckTimeout = 8 * time.Second
|
||||
}
|
||||
if opts.WriteTimeout <= 0 {
|
||||
opts.WriteTimeout = 8 * time.Second
|
||||
}
|
||||
if opts.ReadTimeout <= 0 {
|
||||
opts.ReadTimeout = 70 * time.Second
|
||||
}
|
||||
if opts.HeartbeatInterval <= 0 {
|
||||
opts.HeartbeatInterval = 30 * time.Second
|
||||
}
|
||||
if opts.ReconnectBaseDelay <= 0 {
|
||||
opts.ReconnectBaseDelay = 1 * time.Second
|
||||
}
|
||||
if opts.ReconnectMaxDelay <= 0 {
|
||||
opts.ReconnectMaxDelay = 30 * time.Second
|
||||
}
|
||||
return &WSClient{
|
||||
opts: opts,
|
||||
logger: opts.Logger.With(slog.String("component", "wecom_ws_client")),
|
||||
waiters: make(map[string]chan wsAck),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
|
||||
if err := auth.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
attempt := 0
|
||||
for {
|
||||
err := c.runSession(ctx, auth, onFrame)
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if c.isClosed() {
|
||||
return nil
|
||||
}
|
||||
if c.opts.MaxReconnectAttempts >= 0 && attempt >= c.opts.MaxReconnectAttempts {
|
||||
if err == nil {
|
||||
return fmt.Errorf("wecom websocket reconnect attempts exceeded")
|
||||
}
|
||||
return err
|
||||
}
|
||||
delay := c.backoff(attempt)
|
||||
attempt++
|
||||
c.logger.Warn("wecom websocket session ended; reconnecting",
|
||||
slog.Int("attempt", attempt),
|
||||
slog.Duration("delay", delay),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) runSession(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
|
||||
conn, _, err := c.dial(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.setConn(conn)
|
||||
sessionCtx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
cancel()
|
||||
c.clearConn()
|
||||
c.failAllWaiters(fmt.Errorf("wecom websocket disconnected"))
|
||||
}()
|
||||
|
||||
readErrCh := make(chan error, 1)
|
||||
go c.readLoop(sessionCtx, onFrame, readErrCh)
|
||||
|
||||
if err := c.authenticate(sessionCtx, auth); err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
c.logger.Info("wecom websocket authenticated")
|
||||
|
||||
go c.heartbeatLoop(sessionCtx)
|
||||
|
||||
select {
|
||||
case <-sessionCtx.Done():
|
||||
_ = conn.Close()
|
||||
return sessionCtx.Err()
|
||||
case err := <-readErrCh:
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) dial(ctx context.Context) (*websocket.Conn, *http.Response, error) {
|
||||
dialer := c.opts.Dialer
|
||||
if dialer == nil {
|
||||
dialer = websocket.DefaultDialer
|
||||
}
|
||||
return dialer.DialContext(ctx, c.opts.URL, nil)
|
||||
}
|
||||
|
||||
func (c *WSClient) authenticate(ctx context.Context, auth AuthCredentials) error {
|
||||
frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), SubscribeBody{
|
||||
BotID: strings.TrimSpace(auth.BotID),
|
||||
Secret: strings.TrimSpace(auth.Secret),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ack, err := c.SendWithAck(ctx, frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ack.ErrCode != 0 {
|
||||
return fmt.Errorf("wecom subscribe failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WSClient) heartbeatLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(c.opts.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
frame, err := BuildFrame(WSCmdHeartbeat, NewReqID(WSCmdHeartbeat), nil)
|
||||
if err != nil {
|
||||
c.logger.Error("build heartbeat frame failed", slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
err = c.Send(ctx, frame)
|
||||
if err != nil {
|
||||
c.logger.Warn("wecom websocket heartbeat failed", slog.Any("error", err))
|
||||
if conn := c.getConn(); conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, WSFrame) error, errCh chan<- error) {
|
||||
conn := c.getConn()
|
||||
if conn == nil {
|
||||
errCh <- fmt.Errorf("wecom websocket connection not ready")
|
||||
return
|
||||
}
|
||||
for {
|
||||
if c.opts.ReadTimeout > 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(c.opts.ReadTimeout))
|
||||
}
|
||||
_, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
var frame WSFrame
|
||||
if err := json.Unmarshal(payload, &frame); err != nil {
|
||||
c.logger.Warn("decode websocket frame failed", slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
if c.dispatchAck(frame) {
|
||||
continue
|
||||
}
|
||||
if onFrame == nil {
|
||||
continue
|
||||
}
|
||||
if err := onFrame(ctx, frame); err != nil {
|
||||
c.logger.Warn("wecom onFrame callback returned error", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) Send(ctx context.Context, frame WSFrame) error {
|
||||
if strings.TrimSpace(frame.Headers.ReqID) == "" {
|
||||
return fmt.Errorf("req_id is required")
|
||||
}
|
||||
conn := c.getConn()
|
||||
if conn == nil {
|
||||
return fmt.Errorf("wecom websocket is not connected")
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
if c.opts.WriteTimeout > 0 {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(c.opts.WriteTimeout))
|
||||
}
|
||||
if err := conn.WriteJSON(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WSClient) SendWithAck(ctx context.Context, frame WSFrame) (WSFrame, error) {
|
||||
reqID := strings.TrimSpace(frame.Headers.ReqID)
|
||||
if reqID == "" {
|
||||
return WSFrame{}, fmt.Errorf("req_id is required")
|
||||
}
|
||||
wait := make(chan wsAck, 1)
|
||||
c.waitMu.Lock()
|
||||
c.waiters[reqID] = wait
|
||||
c.waitMu.Unlock()
|
||||
defer func() {
|
||||
c.waitMu.Lock()
|
||||
delete(c.waiters, reqID)
|
||||
c.waitMu.Unlock()
|
||||
}()
|
||||
if err := c.Send(ctx, frame); err != nil {
|
||||
return WSFrame{}, err
|
||||
}
|
||||
timeout := c.opts.AckTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining > 0 && remaining < timeout {
|
||||
timeout = remaining
|
||||
}
|
||||
}
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return WSFrame{}, ctx.Err()
|
||||
case <-timer.C:
|
||||
return WSFrame{}, fmt.Errorf("wait websocket ack timeout for req_id=%s", reqID)
|
||||
case ack := <-wait:
|
||||
if ack.err != nil {
|
||||
return WSFrame{}, ack.err
|
||||
}
|
||||
return ack.frame, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) Close() error {
|
||||
c.connMu.Lock()
|
||||
c.closed = true
|
||||
conn := c.conn
|
||||
c.conn = nil
|
||||
c.connMu.Unlock()
|
||||
c.failAllWaiters(fmt.Errorf("wecom websocket client closed"))
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
func (c *WSClient) Reply(ctx context.Context, reqID string, cmd string, body any) (WSFrame, error) {
|
||||
frame, err := BuildFrame(cmd, reqID, body)
|
||||
if err != nil {
|
||||
return WSFrame{}, err
|
||||
}
|
||||
// WeCom callback reply commands are triggered by inbound callback req_id and may
|
||||
// not always return an explicit ACK frame in production. Waiting for ACK here can
|
||||
// cause false timeouts even when the platform accepts the reply.
|
||||
if isRespondCommand(cmd) {
|
||||
if err := c.Send(ctx, frame); err != nil {
|
||||
return WSFrame{}, err
|
||||
}
|
||||
return WSFrame{}, nil
|
||||
}
|
||||
return c.SendWithAck(ctx, frame)
|
||||
}
|
||||
|
||||
func isRespondCommand(cmd string) bool {
|
||||
switch strings.TrimSpace(cmd) {
|
||||
case WSCmdRespond, WSCmdRespondWelcome, WSCmdRespondUpdate:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) dispatchAck(frame WSFrame) bool {
|
||||
reqID := strings.TrimSpace(frame.Headers.ReqID)
|
||||
if reqID == "" {
|
||||
return false
|
||||
}
|
||||
c.waitMu.Lock()
|
||||
wait, ok := c.waiters[reqID]
|
||||
c.waitMu.Unlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case wait <- wsAck{frame: frame}:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *WSClient) failAllWaiters(cause error) {
|
||||
c.waitMu.Lock()
|
||||
defer c.waitMu.Unlock()
|
||||
for id, wait := range c.waiters {
|
||||
delete(c.waiters, id)
|
||||
select {
|
||||
case wait <- wsAck{err: cause}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) backoff(attempt int) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
delay := c.opts.ReconnectBaseDelay << attempt
|
||||
if delay > c.opts.ReconnectMaxDelay {
|
||||
return c.opts.ReconnectMaxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func (c *WSClient) setConn(conn *websocket.Conn) {
|
||||
c.connMu.Lock()
|
||||
defer c.connMu.Unlock()
|
||||
c.conn = conn
|
||||
}
|
||||
|
||||
func (c *WSClient) clearConn() {
|
||||
c.connMu.Lock()
|
||||
conn := c.conn
|
||||
c.conn = nil
|
||||
c.connMu.Unlock()
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) getConn() *websocket.Conn {
|
||||
c.connMu.RLock()
|
||||
defer c.connMu.RUnlock()
|
||||
return c.conn
|
||||
}
|
||||
|
||||
func (c *WSClient) isClosed() bool {
|
||||
c.connMu.RLock()
|
||||
defer c.connMu.RUnlock()
|
||||
return c.closed
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var connCount atomic.Int32
|
||||
callbackSent := make(chan struct{}, 1)
|
||||
upgrader := websocket.Upgrader{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
n := connCount.Add(1)
|
||||
|
||||
var subscribeFrame WSFrame
|
||||
if err := conn.ReadJSON(&subscribeFrame); err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(WSFrame{
|
||||
Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID},
|
||||
ErrCode: 0,
|
||||
})
|
||||
|
||||
if n == 1 {
|
||||
return
|
||||
}
|
||||
body, _ := json.Marshal(MessageCallbackBody{
|
||||
MsgID: "m1",
|
||||
ChatID: "chat_1",
|
||||
ChatType: "group",
|
||||
CreateTime: time.Now().UnixMilli(),
|
||||
From: CallbackFrom{UserID: "u1"},
|
||||
MsgType: "text",
|
||||
Text: &MessageText{Content: "hello"},
|
||||
})
|
||||
_ = conn.WriteJSON(WSFrame{
|
||||
Cmd: WSCmdMsgCallback,
|
||||
Headers: WSHeaders{ReqID: "cb_req"},
|
||||
Body: body,
|
||||
})
|
||||
select {
|
||||
case callbackSent <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-time.After(100 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
client := NewWSClient(WSClientOptions{
|
||||
URL: wsURL,
|
||||
AckTimeout: 200 * time.Millisecond,
|
||||
HeartbeatInterval: 10 * time.Second,
|
||||
ReconnectBaseDelay: 10 * time.Millisecond,
|
||||
ReconnectMaxDelay: 20 * time.Millisecond,
|
||||
MaxReconnectAttempts: 5,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, func(context.Context, WSFrame) error {
|
||||
cancel()
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-callbackSent:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting callback on reconnected session")
|
||||
}
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err == nil || err != context.Canceled {
|
||||
t.Fatalf("unexpected run error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting run return")
|
||||
}
|
||||
if connCount.Load() < 2 {
|
||||
t.Fatalf("expected reconnect attempts >= 2, got %d", connCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var connCount atomic.Int32
|
||||
heartbeatSeen := make(chan struct{}, 1)
|
||||
upgrader := websocket.Upgrader{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
n := connCount.Add(1)
|
||||
|
||||
var subscribeFrame WSFrame
|
||||
if err := conn.ReadJSON(&subscribeFrame); err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(WSFrame{
|
||||
Headers: WSHeaders{ReqID: subscribeFrame.Headers.ReqID},
|
||||
ErrCode: 0,
|
||||
})
|
||||
if n > 1 {
|
||||
return
|
||||
}
|
||||
var heartbeatFrame WSFrame
|
||||
if err := conn.ReadJSON(&heartbeatFrame); err != nil {
|
||||
return
|
||||
}
|
||||
if heartbeatFrame.Cmd == WSCmdHeartbeat {
|
||||
select {
|
||||
case heartbeatSeen <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
<-time.After(200 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
client := NewWSClient(WSClientOptions{
|
||||
URL: wsURL,
|
||||
HeartbeatInterval: 20 * time.Millisecond,
|
||||
ReconnectBaseDelay: 10 * time.Millisecond,
|
||||
ReconnectMaxDelay: 20 * time.Millisecond,
|
||||
MaxReconnectAttempts: 5,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-heartbeatSeen:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting heartbeat frame")
|
||||
}
|
||||
<-time.After(250 * time.Millisecond)
|
||||
cancel()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err == nil || err != context.Canceled {
|
||||
t.Fatalf("unexpected run error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting run return")
|
||||
}
|
||||
if connCount.Load() < 1 {
|
||||
t.Fatalf("expected at least one session, got %d", connCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user