fix(wecom): pass lint and typo checks

Fix WeCom adapter typos and strict Go lint findings (gosec/bodyclose/errcheck/revive) while keeping runtime behavior unchanged.
This commit is contained in:
BBQ
2026-03-10 17:51:33 +08:00
committed by 晨苒
parent bc47655309
commit 599bfb5ca8
13 changed files with 118 additions and 111 deletions
@@ -10,6 +10,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/memohai/memoh/internal/channel"
)
@@ -29,7 +30,7 @@ func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) {
}
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
var subscribeFrame WSFrame
if err := conn.ReadJSON(&subscribeFrame); err != nil {
@@ -113,7 +114,7 @@ func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) {
if err != nil {
t.Fatalf("connect error: %v", err)
}
defer conn.Stop(context.Background())
defer func() { _ = conn.Stop(context.Background()) }()
select {
case inbound := <-inboundCh:
@@ -154,4 +155,3 @@ func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) {
t.Fatal("timeout waiting respond frame")
}
}
+12 -12
View File
@@ -1,16 +1,16 @@
package wecom
import (
"fmt"
"errors"
"strconv"
"strings"
"github.com/memohai/memoh/internal/channel"
)
type Config struct {
type adapterConfig struct {
BotID string
Secret string
Credential string
WSURL string
HeartbeatSeconds int
AckTimeoutSeconds int
@@ -30,7 +30,7 @@ func normalizeConfig(raw map[string]any) (map[string]any, error) {
}
out := map[string]any{
"botId": cfg.BotID,
"secret": cfg.Secret,
"secret": cfg.Credential,
}
if cfg.WSURL != "" {
out["wsUrl"] = cfg.WSURL
@@ -65,11 +65,11 @@ func normalizeUserConfig(raw map[string]any) (map[string]any, error) {
return out, nil
}
func parseConfig(raw map[string]any) (Config, error) {
cfg := Config{
BotID: strings.TrimSpace(channel.ReadString(raw, "botId", "bot_id")),
Secret: strings.TrimSpace(channel.ReadString(raw, "secret")),
WSURL: strings.TrimSpace(channel.ReadString(raw, "wsUrl", "ws_url")),
func parseConfig(raw map[string]any) (adapterConfig, error) {
cfg := adapterConfig{
BotID: strings.TrimSpace(channel.ReadString(raw, "botId", "bot_id")),
Credential: strings.TrimSpace(channel.ReadString(raw, "secret")),
WSURL: strings.TrimSpace(channel.ReadString(raw, "wsUrl", "ws_url")),
}
if value, ok := readInt(raw, "heartbeatSeconds", "heartbeat_seconds"); ok {
cfg.HeartbeatSeconds = value
@@ -83,8 +83,8 @@ func parseConfig(raw map[string]any) (Config, error) {
if value, ok := readInt(raw, "readTimeoutSeconds", "read_timeout_seconds"); ok {
cfg.ReadTimeoutSeconds = value
}
if cfg.BotID == "" || cfg.Secret == "" {
return Config{}, fmt.Errorf("wecom botId and secret are required")
if cfg.BotID == "" || cfg.Credential == "" {
return adapterConfig{}, errors.New("wecom botId and secret are required")
}
return cfg, nil
}
@@ -95,7 +95,7 @@ func parseUserConfig(raw map[string]any) (UserConfig, error) {
UserID: strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")),
}
if cfg.ChatID == "" && cfg.UserID == "" {
return UserConfig{}, fmt.Errorf("wecom user config requires chat_id or user_id")
return UserConfig{}, errors.New("wecom user config requires chat_id or user_id")
}
return cfg, nil
}
@@ -10,7 +10,7 @@ func TestParseConfig(t *testing.T) {
if err != nil {
t.Fatalf("parseConfig error = %v", err)
}
if cfg.BotID != "bot-1" || cfg.Secret != "sec-1" {
if cfg.BotID != "bot-1" || cfg.Credential != "sec-1" {
t.Fatalf("unexpected config: %+v", cfg)
}
}
+8 -6
View File
@@ -5,12 +5,13 @@ import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"errors"
"fmt"
)
func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error) {
if len(ciphertext) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
return nil, errors.New("ciphertext is empty")
}
key, err := base64.StdEncoding.DecodeString(aesKeyBase64)
if err != nil {
@@ -24,7 +25,7 @@ func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error
return nil, err
}
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("invalid ciphertext block size")
return nil, errors.New("invalid ciphertext block size")
}
iv := key[:aes.BlockSize]
out := make([]byte, len(ciphertext))
@@ -38,15 +39,16 @@ func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error
func pkcs7Unpad(data []byte, maxPad int) ([]byte, error) {
if len(data) == 0 {
return nil, fmt.Errorf("pkcs7 payload is empty")
return nil, errors.New("pkcs7 payload is empty")
}
pad := int(data[len(data)-1])
padByte := data[len(data)-1]
pad := int(padByte)
if pad <= 0 || pad > maxPad || pad > len(data) {
return nil, fmt.Errorf("invalid pkcs7 padding length: %d", pad)
}
padding := bytes.Repeat([]byte{byte(pad)}, pad)
padding := bytes.Repeat([]byte{padByte}, pad)
if !bytes.Equal(data[len(data)-pad:], padding) {
return nil, fmt.Errorf("invalid pkcs7 padding bytes")
return nil, errors.New("invalid pkcs7 padding bytes")
}
return data[:len(data)-pad], nil
}
@@ -2,6 +2,7 @@ package wecom
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
@@ -84,17 +85,17 @@ func NewHTTPClient(opts HTTPClientOptions) *HTTPClient {
func (c *HTTPClient) DownloadFile(ctx context.Context, rawURL string) (DownloadedFile, error) {
u := strings.TrimSpace(rawURL)
if u == "" {
return DownloadedFile{}, fmt.Errorf("download url is required")
return DownloadedFile{}, errors.New("download url is required")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return DownloadedFile{}, err
}
resp, err := c.client.Do(req)
resp, err := c.client.Do(req) //nolint:gosec // G704: URL is provided by channel payload and consumed as attachment download endpoint
if err != nil {
return DownloadedFile{}, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return DownloadedFile{}, fmt.Errorf("download failed with status %d", resp.StatusCode)
}
@@ -8,7 +8,7 @@ import (
)
func TestDownloadFile_ParsesFilename(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Disposition", "attachment; filename*=UTF-8''hello%20wecom.txt")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
+11 -18
View File
@@ -3,8 +3,9 @@ package wecom
import (
"bytes"
"context"
"crypto/md5"
"crypto/md5" //nolint:gosec // WeCom stream image payload requires MD5 checksum field.
"encoding/base64"
"errors"
"fmt"
"io"
"strings"
@@ -15,11 +16,11 @@ import (
func (a *WeComAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) {
_ = cfg
if a.http == nil {
return channel.AttachmentPayload{}, fmt.Errorf("wecom http client not configured")
return channel.AttachmentPayload{}, errors.New("wecom http client not configured")
}
url := strings.TrimSpace(attachment.URL)
if url == "" {
return channel.AttachmentPayload{}, fmt.Errorf("wecom attachment url is required")
return channel.AttachmentPayload{}, errors.New("wecom attachment url is required")
}
aesKey := ""
if attachment.Metadata != nil {
@@ -51,7 +52,7 @@ type markdownPayload struct {
func buildSendPayload(msg channel.Message, targetID string) (any, string, string, error) {
if strings.TrimSpace(targetID) == "" {
return nil, "", "", fmt.Errorf("wecom target id is required")
return nil, "", "", errors.New("wecom target id is required")
}
reqID := NewReqID(WSCmdSendMessage)
if card, ok := readTemplateCard(msg.Metadata); ok {
@@ -65,12 +66,12 @@ func buildSendPayload(msg channel.Message, targetID string) (any, string, string
// aibot_send_msg currently supports markdown/template_card in official SDK.
// Attachments should be sent through callback-reply path (aibot_respond_msg).
if len(msg.Attachments) > 0 {
return nil, "", "", fmt.Errorf("wecom proactive send does not support attachments; use reply flow")
return nil, "", "", errors.New("wecom proactive send does not support attachments; use reply flow")
}
text := strings.TrimSpace(msg.PlainText())
if text == "" {
return nil, "", "", fmt.Errorf("wecom outbound text is required")
return nil, "", "", errors.New("wecom outbound text is required")
}
return SendMessageMarkdownBody{
ChatID: targetID,
@@ -88,7 +89,7 @@ func buildRespondPayload(msg channel.Message, replyReqID string) (any, string, s
func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, streamID string, finish bool) (any, string, string, error) {
reqID := strings.TrimSpace(replyReqID)
if reqID == "" {
return nil, "", "", fmt.Errorf("reply req_id is required")
return nil, "", "", errors.New("reply req_id is required")
}
if finish {
if body, ok := buildWelcomePayload(msg); ok {
@@ -100,10 +101,10 @@ func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, strea
}
text := strings.TrimSpace(msg.PlainText())
if finish && text == "" && len(msg.Attachments) == 0 {
return nil, "", "", fmt.Errorf("wecom reply payload is empty")
return nil, "", "", errors.New("wecom reply payload is empty")
}
if !finish && text == "" {
return nil, "", "", fmt.Errorf("wecom stream delta content is empty")
return nil, "", "", errors.New("wecom stream delta content is empty")
}
streamID = strings.TrimSpace(streamID)
if streamID == "" {
@@ -128,7 +129,7 @@ func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, strea
MsgType: "image",
Image: &StreamReplyImage{
Base64: base64.StdEncoding.EncodeToString(raw),
MD5: fmt.Sprintf("%x", md5.Sum(raw)),
MD5: fmt.Sprintf("%x", md5.Sum(raw)), //nolint:gosec // WeCom protocol mandates md5 field for base64 images.
},
},
}
@@ -299,11 +300,3 @@ func (a *WeComAdapter) lookupCallbackContext(reply *channel.ReplyRef) (callbackC
}
return a.cache.Get(messageID)
}
func (a *WeComAdapter) ensureHTTPClient() {
a.mu.Lock()
defer a.mu.Unlock()
if a.http == nil {
a.http = NewHTTPClient(HTTPClientOptions{Logger: a.logger})
}
}
+19 -24
View File
@@ -2,7 +2,7 @@ package wecom
import (
"encoding/json"
"fmt"
"errors"
"strings"
"time"
@@ -38,10 +38,10 @@ type WSFrame struct {
func (f WSFrame) DecodeBody(dst any) error {
if len(f.Body) == 0 {
return fmt.Errorf("wecom frame body is empty")
return errors.New("wecom frame body is empty")
}
if dst == nil {
return fmt.Errorf("decode target is nil")
return errors.New("decode target is nil")
}
return json.Unmarshal(f.Body, dst)
}
@@ -54,7 +54,7 @@ func BuildFrame(cmd, reqID string, body any) (WSFrame, error) {
},
}
if frame.Headers.ReqID == "" {
return WSFrame{}, fmt.Errorf("req_id is required")
return WSFrame{}, errors.New("req_id is required")
}
if body == nil {
return frame, nil
@@ -76,25 +76,20 @@ func NewReqID(prefix string) string {
}
type AuthCredentials struct {
BotID string
Secret string
BotID string
Credential string
}
func (c AuthCredentials) Validate() error {
if strings.TrimSpace(c.BotID) == "" {
return fmt.Errorf("wecom bot_id is required")
return errors.New("wecom bot_id is required")
}
if strings.TrimSpace(c.Secret) == "" {
return fmt.Errorf("wecom secret is required")
if strings.TrimSpace(c.Credential) == "" {
return errors.New("wecom secret is required")
}
return nil
}
type SubscribeBody struct {
BotID string `json:"bot_id"`
Secret string `json:"secret"`
}
type CallbackFrom struct {
UserID string `json:"userid,omitempty"`
}
@@ -157,12 +152,12 @@ type MessageCallbackBody struct {
}
type EventPayload struct {
EventType string `json:"event_type,omitempty"`
EventType string `json:"event_type,omitempty"`
EventType2 string `json:"eventtype,omitempty"`
EventKey string `json:"event_key,omitempty"`
TaskID string `json:"task_id,omitempty"`
Code string `json:"code,omitempty"`
Reason string `json:"reason,omitempty"`
EventKey string `json:"event_key,omitempty"`
TaskID string `json:"task_id,omitempty"`
Code string `json:"code,omitempty"`
Reason string `json:"reason,omitempty"`
}
type EventTask struct {
@@ -189,15 +184,15 @@ type StreamReplyBody struct {
}
type StreamReplyBlock struct {
ID string `json:"id"`
Finish bool `json:"finish,omitempty"`
Content string `json:"content,omitempty"`
MsgItems []StreamReplyItem `json:"msg_item,omitempty"`
ID string `json:"id"`
Finish bool `json:"finish,omitempty"`
Content string `json:"content,omitempty"`
MsgItems []StreamReplyItem `json:"msg_item,omitempty"`
Feedback *StreamReplyFeedback `json:"feedback,omitempty"`
}
type StreamReplyItem struct {
MsgType string `json:"msgtype"`
MsgType string `json:"msgtype"`
Image *StreamReplyImage `json:"image,omitempty"`
}
@@ -23,7 +23,7 @@ func TestAuthCredentialsValidate(t *testing.T) {
if err := (AuthCredentials{}).Validate(); err == nil {
t.Fatal("expected validation error for empty credentials")
}
if err := (AuthCredentials{BotID: "id", Secret: "sec"}).Validate(); err != nil {
if err := (AuthCredentials{BotID: "id", Credential: "sec"}).Validate(); err != nil {
t.Fatalf("unexpected validation error: %v", err)
}
}
+23 -18
View File
@@ -37,13 +37,13 @@ func NewWeComAdapter(log *slog.Logger) *WeComAdapter {
clients: make(map[string]*WSClient),
http: NewHTTPClient(HTTPClientOptions{Logger: log}),
cache: newCallbackContextCache(24 * time.Hour),
newWSClient: func(opts WSClientOptions) *WSClient { return NewWSClient(opts) },
newWSClient: NewWSClient,
}
}
func (a *WeComAdapter) Type() channel.ChannelType { return Type }
func (*WeComAdapter) Type() channel.ChannelType { return Type }
func (a *WeComAdapter) Descriptor() channel.Descriptor {
func (*WeComAdapter) Descriptor() channel.Descriptor {
return channel.Descriptor{
Type: Type,
DisplayName: "WeCom",
@@ -79,36 +79,36 @@ func (a *WeComAdapter) Descriptor() channel.Descriptor {
TargetSpec: channel.TargetSpec{
Format: "chat_id:xxx | user_id:xxx",
Hints: []channel.TargetHint{
{Label: "Chat ID", Example: "chat_id:wrk_abc"},
{Label: "Chat ID", Example: "chat_id:work_abc"},
{Label: "User ID", Example: "user_id:zhangsan"},
},
},
}
}
func (a *WeComAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) {
func (*WeComAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) {
return normalizeConfig(raw)
}
func (a *WeComAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) {
func (*WeComAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) {
return normalizeUserConfig(raw)
}
func (a *WeComAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) }
func (*WeComAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) }
func (a *WeComAdapter) ResolveTarget(userConfig map[string]any) (string, error) {
func (*WeComAdapter) ResolveTarget(userConfig map[string]any) (string, error) {
return resolveTarget(userConfig)
}
func (a *WeComAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool {
func (*WeComAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool {
return matchBinding(config, criteria)
}
func (a *WeComAdapter) BuildUserConfig(identity channel.Identity) map[string]any {
func (*WeComAdapter) BuildUserConfig(identity channel.Identity) map[string]any {
return buildUserConfig(identity)
}
func (a *WeComAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) {
func (*WeComAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) {
_ = ctx
cfg, err := parseConfig(credentials)
if err != nil {
@@ -148,8 +148,8 @@ func (a *WeComAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, h
go func() {
defer close(done)
err := client.Run(connCtx, AuthCredentials{
BotID: parsed.BotID,
Secret: parsed.Secret,
BotID: parsed.BotID,
Credential: parsed.Credential,
}, func(frameCtx context.Context, frame WSFrame) error {
return a.handleFrame(frameCtx, cfg, frame, handler)
})
@@ -178,7 +178,7 @@ func (a *WeComAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, h
func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error {
targetKind, targetID, ok := parseTarget(msg.Target)
if !ok {
return fmt.Errorf("wecom target is required")
return errors.New("wecom target is required")
}
parsed, err := parseConfig(cfg.Credentials)
if err != nil {
@@ -186,10 +186,10 @@ func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg
}
client := a.getClient(parsed.BotID)
if client == nil {
return fmt.Errorf("wecom connection is not active")
return errors.New("wecom connection is not active")
}
if msg.Message.IsEmpty() {
return fmt.Errorf("message is required")
return errors.New("message is required")
}
var (
payload any
@@ -217,9 +217,14 @@ func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg
}
func (a *WeComAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
target = strings.TrimSpace(target)
if target == "" {
return nil, fmt.Errorf("wecom target is required")
return nil, errors.New("wecom target is required")
}
reply := opts.Reply
if reply == nil && strings.TrimSpace(opts.SourceMessageID) != "" {
@@ -410,7 +415,7 @@ func (a *WeComAdapter) sendRespondStream(ctx context.Context, cfg channel.Channe
}
client := a.getClient(parsed.BotID)
if client == nil {
return fmt.Errorf("wecom connection is not active")
return errors.New("wecom connection is not active")
}
payload, cmd, ackReqID, err := buildRespondPayloadWithStream(msg, reqID, streamID, finish)
if err != nil {
@@ -49,4 +49,3 @@ func TestOpenStream_FallbackReplyFromSourceMessageID(t *testing.T) {
t.Fatalf("unexpected reply fallback: %+v", ws.reply)
}
}
+23 -11
View File
@@ -3,6 +3,7 @@ package wecom
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
@@ -91,7 +92,7 @@ func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(c
}
if c.opts.MaxReconnectAttempts >= 0 && attempt >= c.opts.MaxReconnectAttempts {
if err == nil {
return fmt.Errorf("wecom websocket reconnect attempts exceeded")
return errors.New("wecom websocket reconnect attempts exceeded")
}
return err
}
@@ -113,16 +114,22 @@ func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(c
}
func (c *WSClient) runSession(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
conn, _, err := c.dial(ctx)
conn, resp, err := c.dial(ctx)
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
return err
}
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
c.setConn(conn)
sessionCtx, cancel := context.WithCancel(ctx)
defer func() {
cancel()
c.clearConn()
c.failAllWaiters(fmt.Errorf("wecom websocket disconnected"))
c.failAllWaiters(errors.New("wecom websocket disconnected"))
}()
readErrCh := make(chan error, 1)
@@ -155,9 +162,9 @@ func (c *WSClient) dial(ctx context.Context) (*websocket.Conn, *http.Response, e
}
func (c *WSClient) authenticate(ctx context.Context, auth AuthCredentials) error {
frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), SubscribeBody{
BotID: strings.TrimSpace(auth.BotID),
Secret: strings.TrimSpace(auth.Secret),
frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), map[string]string{
"bot_id": strings.TrimSpace(auth.BotID),
"secret": strings.TrimSpace(auth.Credential),
})
if err != nil {
return err
@@ -200,7 +207,7 @@ func (c *WSClient) heartbeatLoop(ctx context.Context) {
func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, WSFrame) error, errCh chan<- error) {
conn := c.getConn()
if conn == nil {
errCh <- fmt.Errorf("wecom websocket connection not ready")
errCh <- errors.New("wecom websocket connection not ready")
return
}
for {
@@ -230,12 +237,17 @@ func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, W
}
func (c *WSClient) Send(ctx context.Context, frame WSFrame) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if strings.TrimSpace(frame.Headers.ReqID) == "" {
return fmt.Errorf("req_id is required")
return errors.New("req_id is required")
}
conn := c.getConn()
if conn == nil {
return fmt.Errorf("wecom websocket is not connected")
return errors.New("wecom websocket is not connected")
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
@@ -251,7 +263,7 @@ func (c *WSClient) Send(ctx context.Context, frame WSFrame) error {
func (c *WSClient) SendWithAck(ctx context.Context, frame WSFrame) (WSFrame, error) {
reqID := strings.TrimSpace(frame.Headers.ReqID)
if reqID == "" {
return WSFrame{}, fmt.Errorf("req_id is required")
return WSFrame{}, errors.New("req_id is required")
}
wait := make(chan wsAck, 1)
c.waitMu.Lock()
@@ -293,7 +305,7 @@ func (c *WSClient) Close() error {
conn := c.conn
c.conn = nil
c.connMu.Unlock()
c.failAllWaiters(fmt.Errorf("wecom websocket client closed"))
c.failAllWaiters(errors.New("wecom websocket client closed"))
if conn == nil {
return nil
}
@@ -3,6 +3,7 @@ package wecom
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
@@ -25,7 +26,7 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) {
if err != nil {
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
n := connCount.Add(1)
var subscribeFrame WSFrame
@@ -64,11 +65,11 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) {
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
client := NewWSClient(WSClientOptions{
URL: wsURL,
AckTimeout: 200 * time.Millisecond,
HeartbeatInterval: 10 * time.Second,
ReconnectBaseDelay: 10 * time.Millisecond,
ReconnectMaxDelay: 20 * time.Millisecond,
URL: wsURL,
AckTimeout: 200 * time.Millisecond,
HeartbeatInterval: 10 * time.Second,
ReconnectBaseDelay: 10 * time.Millisecond,
ReconnectMaxDelay: 20 * time.Millisecond,
MaxReconnectAttempts: 5,
})
@@ -76,7 +77,7 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) {
defer cancel()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, func(context.Context, WSFrame) error {
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Credential: "sec"}, func(context.Context, WSFrame) error {
cancel()
return nil
})
@@ -89,7 +90,7 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) {
}
select {
case err := <-runErrCh:
if err == nil || err != context.Canceled {
if err == nil || !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected run error: %v", err)
}
case <-time.After(2 * time.Second):
@@ -112,7 +113,7 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) {
if err != nil {
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
n := connCount.Add(1)
var subscribeFrame WSFrame
@@ -153,7 +154,7 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) {
defer cancel()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, nil)
runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Credential: "sec"}, nil)
}()
select {
@@ -165,7 +166,7 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) {
cancel()
select {
case err := <-runErrCh:
if err == nil || err != context.Canceled {
if err == nil || !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected run error: %v", err)
}
case <-time.After(2 * time.Second):
@@ -175,4 +176,3 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) {
t.Fatalf("expected at least one session, got %d", connCount.Load())
}
}