From 599bfb5ca8b519d9cf0d5be8081bfb53d6157eae Mon Sep 17 00:00:00 2001 From: BBQ Date: Tue, 10 Mar 2026 17:51:33 +0800 Subject: [PATCH] fix(wecom): pass lint and typo checks Fix WeCom adapter typos and strict Go lint findings (gosec/bodyclose/errcheck/revive) while keeping runtime behavior unchanged. --- .../wecom/adapter_integration_test.go | 6 +-- internal/channel/adapters/wecom/config.go | 24 +++++------ .../channel/adapters/wecom/config_test.go | 2 +- internal/channel/adapters/wecom/crypto.go | 14 +++--- .../channel/adapters/wecom/http_client.go | 7 +-- .../adapters/wecom/http_client_test.go | 2 +- internal/channel/adapters/wecom/outbound.go | 29 +++++-------- internal/channel/adapters/wecom/protocol.go | 43 ++++++++----------- .../channel/adapters/wecom/protocol_test.go | 2 +- internal/channel/adapters/wecom/wecom.go | 41 ++++++++++-------- internal/channel/adapters/wecom/wecom_test.go | 1 - internal/channel/adapters/wecom/ws_client.go | 34 ++++++++++----- .../channel/adapters/wecom/ws_client_test.go | 24 +++++------ 13 files changed, 118 insertions(+), 111 deletions(-) diff --git a/internal/channel/adapters/wecom/adapter_integration_test.go b/internal/channel/adapters/wecom/adapter_integration_test.go index d90f058e..c3d8cc3f 100644 --- a/internal/channel/adapters/wecom/adapter_integration_test.go +++ b/internal/channel/adapters/wecom/adapter_integration_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/memohai/memoh/internal/channel" ) @@ -29,7 +30,7 @@ func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) { } return } - defer conn.Close() + defer func() { _ = conn.Close() }() var subscribeFrame WSFrame if err := conn.ReadJSON(&subscribeFrame); err != nil { @@ -113,7 +114,7 @@ func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) { if err != nil { t.Fatalf("connect error: %v", err) } - defer conn.Stop(context.Background()) + defer func() { _ = conn.Stop(context.Background()) }() select { case inbound := <-inboundCh: @@ -154,4 +155,3 @@ func TestWeComAdapter_ReplyUsesRespondCmd(t *testing.T) { t.Fatal("timeout waiting respond frame") } } - diff --git a/internal/channel/adapters/wecom/config.go b/internal/channel/adapters/wecom/config.go index 4c595e60..e8744c28 100644 --- a/internal/channel/adapters/wecom/config.go +++ b/internal/channel/adapters/wecom/config.go @@ -1,16 +1,16 @@ package wecom import ( - "fmt" + "errors" "strconv" "strings" "github.com/memohai/memoh/internal/channel" ) -type Config struct { +type adapterConfig struct { BotID string - Secret string + Credential string WSURL string HeartbeatSeconds int AckTimeoutSeconds int @@ -30,7 +30,7 @@ func normalizeConfig(raw map[string]any) (map[string]any, error) { } out := map[string]any{ "botId": cfg.BotID, - "secret": cfg.Secret, + "secret": cfg.Credential, } if cfg.WSURL != "" { out["wsUrl"] = cfg.WSURL @@ -65,11 +65,11 @@ func normalizeUserConfig(raw map[string]any) (map[string]any, error) { return out, nil } -func parseConfig(raw map[string]any) (Config, error) { - cfg := Config{ - BotID: strings.TrimSpace(channel.ReadString(raw, "botId", "bot_id")), - Secret: strings.TrimSpace(channel.ReadString(raw, "secret")), - WSURL: strings.TrimSpace(channel.ReadString(raw, "wsUrl", "ws_url")), +func parseConfig(raw map[string]any) (adapterConfig, error) { + cfg := adapterConfig{ + BotID: strings.TrimSpace(channel.ReadString(raw, "botId", "bot_id")), + Credential: strings.TrimSpace(channel.ReadString(raw, "secret")), + WSURL: strings.TrimSpace(channel.ReadString(raw, "wsUrl", "ws_url")), } if value, ok := readInt(raw, "heartbeatSeconds", "heartbeat_seconds"); ok { cfg.HeartbeatSeconds = value @@ -83,8 +83,8 @@ func parseConfig(raw map[string]any) (Config, error) { if value, ok := readInt(raw, "readTimeoutSeconds", "read_timeout_seconds"); ok { cfg.ReadTimeoutSeconds = value } - if cfg.BotID == "" || cfg.Secret == "" { - return Config{}, fmt.Errorf("wecom botId and secret are required") + if cfg.BotID == "" || cfg.Credential == "" { + return adapterConfig{}, errors.New("wecom botId and secret are required") } return cfg, nil } @@ -95,7 +95,7 @@ func parseUserConfig(raw map[string]any) (UserConfig, error) { UserID: strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")), } if cfg.ChatID == "" && cfg.UserID == "" { - return UserConfig{}, fmt.Errorf("wecom user config requires chat_id or user_id") + return UserConfig{}, errors.New("wecom user config requires chat_id or user_id") } return cfg, nil } diff --git a/internal/channel/adapters/wecom/config_test.go b/internal/channel/adapters/wecom/config_test.go index ea89aa8c..5426da5e 100644 --- a/internal/channel/adapters/wecom/config_test.go +++ b/internal/channel/adapters/wecom/config_test.go @@ -10,7 +10,7 @@ func TestParseConfig(t *testing.T) { if err != nil { t.Fatalf("parseConfig error = %v", err) } - if cfg.BotID != "bot-1" || cfg.Secret != "sec-1" { + if cfg.BotID != "bot-1" || cfg.Credential != "sec-1" { t.Fatalf("unexpected config: %+v", cfg) } } diff --git a/internal/channel/adapters/wecom/crypto.go b/internal/channel/adapters/wecom/crypto.go index ca60003e..761de5ad 100644 --- a/internal/channel/adapters/wecom/crypto.go +++ b/internal/channel/adapters/wecom/crypto.go @@ -5,12 +5,13 @@ import ( "crypto/aes" "crypto/cipher" "encoding/base64" + "errors" "fmt" ) func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error) { if len(ciphertext) == 0 { - return nil, fmt.Errorf("ciphertext is empty") + return nil, errors.New("ciphertext is empty") } key, err := base64.StdEncoding.DecodeString(aesKeyBase64) if err != nil { @@ -24,7 +25,7 @@ func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error return nil, err } if len(ciphertext)%aes.BlockSize != 0 { - return nil, fmt.Errorf("invalid ciphertext block size") + return nil, errors.New("invalid ciphertext block size") } iv := key[:aes.BlockSize] out := make([]byte, len(ciphertext)) @@ -38,15 +39,16 @@ func DecryptFileAES256CBC(ciphertext []byte, aesKeyBase64 string) ([]byte, error func pkcs7Unpad(data []byte, maxPad int) ([]byte, error) { if len(data) == 0 { - return nil, fmt.Errorf("pkcs7 payload is empty") + return nil, errors.New("pkcs7 payload is empty") } - pad := int(data[len(data)-1]) + padByte := data[len(data)-1] + pad := int(padByte) if pad <= 0 || pad > maxPad || pad > len(data) { return nil, fmt.Errorf("invalid pkcs7 padding length: %d", pad) } - padding := bytes.Repeat([]byte{byte(pad)}, pad) + padding := bytes.Repeat([]byte{padByte}, pad) if !bytes.Equal(data[len(data)-pad:], padding) { - return nil, fmt.Errorf("invalid pkcs7 padding bytes") + return nil, errors.New("invalid pkcs7 padding bytes") } return data[:len(data)-pad], nil } diff --git a/internal/channel/adapters/wecom/http_client.go b/internal/channel/adapters/wecom/http_client.go index 091048b7..e5098ca8 100644 --- a/internal/channel/adapters/wecom/http_client.go +++ b/internal/channel/adapters/wecom/http_client.go @@ -2,6 +2,7 @@ package wecom import ( "context" + "errors" "fmt" "io" "log/slog" @@ -84,17 +85,17 @@ func NewHTTPClient(opts HTTPClientOptions) *HTTPClient { func (c *HTTPClient) DownloadFile(ctx context.Context, rawURL string) (DownloadedFile, error) { u := strings.TrimSpace(rawURL) if u == "" { - return DownloadedFile{}, fmt.Errorf("download url is required") + return DownloadedFile{}, errors.New("download url is required") } req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) if err != nil { return DownloadedFile{}, err } - resp, err := c.client.Do(req) + resp, err := c.client.Do(req) //nolint:gosec // G704: URL is provided by channel payload and consumed as attachment download endpoint if err != nil { return DownloadedFile{}, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return DownloadedFile{}, fmt.Errorf("download failed with status %d", resp.StatusCode) } diff --git a/internal/channel/adapters/wecom/http_client_test.go b/internal/channel/adapters/wecom/http_client_test.go index 428d335e..6adeb799 100644 --- a/internal/channel/adapters/wecom/http_client_test.go +++ b/internal/channel/adapters/wecom/http_client_test.go @@ -8,7 +8,7 @@ import ( ) func TestDownloadFile_ParsesFilename(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Disposition", "attachment; filename*=UTF-8''hello%20wecom.txt") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) diff --git a/internal/channel/adapters/wecom/outbound.go b/internal/channel/adapters/wecom/outbound.go index 6ac1d7c4..901a3759 100644 --- a/internal/channel/adapters/wecom/outbound.go +++ b/internal/channel/adapters/wecom/outbound.go @@ -3,8 +3,9 @@ package wecom import ( "bytes" "context" - "crypto/md5" + "crypto/md5" //nolint:gosec // WeCom stream image payload requires MD5 checksum field. "encoding/base64" + "errors" "fmt" "io" "strings" @@ -15,11 +16,11 @@ import ( func (a *WeComAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) { _ = cfg if a.http == nil { - return channel.AttachmentPayload{}, fmt.Errorf("wecom http client not configured") + return channel.AttachmentPayload{}, errors.New("wecom http client not configured") } url := strings.TrimSpace(attachment.URL) if url == "" { - return channel.AttachmentPayload{}, fmt.Errorf("wecom attachment url is required") + return channel.AttachmentPayload{}, errors.New("wecom attachment url is required") } aesKey := "" if attachment.Metadata != nil { @@ -51,7 +52,7 @@ type markdownPayload struct { func buildSendPayload(msg channel.Message, targetID string) (any, string, string, error) { if strings.TrimSpace(targetID) == "" { - return nil, "", "", fmt.Errorf("wecom target id is required") + return nil, "", "", errors.New("wecom target id is required") } reqID := NewReqID(WSCmdSendMessage) if card, ok := readTemplateCard(msg.Metadata); ok { @@ -65,12 +66,12 @@ func buildSendPayload(msg channel.Message, targetID string) (any, string, string // aibot_send_msg currently supports markdown/template_card in official SDK. // Attachments should be sent through callback-reply path (aibot_respond_msg). if len(msg.Attachments) > 0 { - return nil, "", "", fmt.Errorf("wecom proactive send does not support attachments; use reply flow") + return nil, "", "", errors.New("wecom proactive send does not support attachments; use reply flow") } text := strings.TrimSpace(msg.PlainText()) if text == "" { - return nil, "", "", fmt.Errorf("wecom outbound text is required") + return nil, "", "", errors.New("wecom outbound text is required") } return SendMessageMarkdownBody{ ChatID: targetID, @@ -88,7 +89,7 @@ func buildRespondPayload(msg channel.Message, replyReqID string) (any, string, s func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, streamID string, finish bool) (any, string, string, error) { reqID := strings.TrimSpace(replyReqID) if reqID == "" { - return nil, "", "", fmt.Errorf("reply req_id is required") + return nil, "", "", errors.New("reply req_id is required") } if finish { if body, ok := buildWelcomePayload(msg); ok { @@ -100,10 +101,10 @@ func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, strea } text := strings.TrimSpace(msg.PlainText()) if finish && text == "" && len(msg.Attachments) == 0 { - return nil, "", "", fmt.Errorf("wecom reply payload is empty") + return nil, "", "", errors.New("wecom reply payload is empty") } if !finish && text == "" { - return nil, "", "", fmt.Errorf("wecom stream delta content is empty") + return nil, "", "", errors.New("wecom stream delta content is empty") } streamID = strings.TrimSpace(streamID) if streamID == "" { @@ -128,7 +129,7 @@ func buildRespondPayloadWithStream(msg channel.Message, replyReqID string, strea MsgType: "image", Image: &StreamReplyImage{ Base64: base64.StdEncoding.EncodeToString(raw), - MD5: fmt.Sprintf("%x", md5.Sum(raw)), + MD5: fmt.Sprintf("%x", md5.Sum(raw)), //nolint:gosec // WeCom protocol mandates md5 field for base64 images. }, }, } @@ -299,11 +300,3 @@ func (a *WeComAdapter) lookupCallbackContext(reply *channel.ReplyRef) (callbackC } return a.cache.Get(messageID) } - -func (a *WeComAdapter) ensureHTTPClient() { - a.mu.Lock() - defer a.mu.Unlock() - if a.http == nil { - a.http = NewHTTPClient(HTTPClientOptions{Logger: a.logger}) - } -} diff --git a/internal/channel/adapters/wecom/protocol.go b/internal/channel/adapters/wecom/protocol.go index 69e2ca9b..a8b311f0 100644 --- a/internal/channel/adapters/wecom/protocol.go +++ b/internal/channel/adapters/wecom/protocol.go @@ -2,7 +2,7 @@ package wecom import ( "encoding/json" - "fmt" + "errors" "strings" "time" @@ -38,10 +38,10 @@ type WSFrame struct { func (f WSFrame) DecodeBody(dst any) error { if len(f.Body) == 0 { - return fmt.Errorf("wecom frame body is empty") + return errors.New("wecom frame body is empty") } if dst == nil { - return fmt.Errorf("decode target is nil") + return errors.New("decode target is nil") } return json.Unmarshal(f.Body, dst) } @@ -54,7 +54,7 @@ func BuildFrame(cmd, reqID string, body any) (WSFrame, error) { }, } if frame.Headers.ReqID == "" { - return WSFrame{}, fmt.Errorf("req_id is required") + return WSFrame{}, errors.New("req_id is required") } if body == nil { return frame, nil @@ -76,25 +76,20 @@ func NewReqID(prefix string) string { } type AuthCredentials struct { - BotID string - Secret string + BotID string + Credential string } func (c AuthCredentials) Validate() error { if strings.TrimSpace(c.BotID) == "" { - return fmt.Errorf("wecom bot_id is required") + return errors.New("wecom bot_id is required") } - if strings.TrimSpace(c.Secret) == "" { - return fmt.Errorf("wecom secret is required") + if strings.TrimSpace(c.Credential) == "" { + return errors.New("wecom secret is required") } return nil } -type SubscribeBody struct { - BotID string `json:"bot_id"` - Secret string `json:"secret"` -} - type CallbackFrom struct { UserID string `json:"userid,omitempty"` } @@ -157,12 +152,12 @@ type MessageCallbackBody struct { } type EventPayload struct { - EventType string `json:"event_type,omitempty"` + EventType string `json:"event_type,omitempty"` EventType2 string `json:"eventtype,omitempty"` - EventKey string `json:"event_key,omitempty"` - TaskID string `json:"task_id,omitempty"` - Code string `json:"code,omitempty"` - Reason string `json:"reason,omitempty"` + EventKey string `json:"event_key,omitempty"` + TaskID string `json:"task_id,omitempty"` + Code string `json:"code,omitempty"` + Reason string `json:"reason,omitempty"` } type EventTask struct { @@ -189,15 +184,15 @@ type StreamReplyBody struct { } type StreamReplyBlock struct { - ID string `json:"id"` - Finish bool `json:"finish,omitempty"` - Content string `json:"content,omitempty"` - MsgItems []StreamReplyItem `json:"msg_item,omitempty"` + ID string `json:"id"` + Finish bool `json:"finish,omitempty"` + Content string `json:"content,omitempty"` + MsgItems []StreamReplyItem `json:"msg_item,omitempty"` Feedback *StreamReplyFeedback `json:"feedback,omitempty"` } type StreamReplyItem struct { - MsgType string `json:"msgtype"` + MsgType string `json:"msgtype"` Image *StreamReplyImage `json:"image,omitempty"` } diff --git a/internal/channel/adapters/wecom/protocol_test.go b/internal/channel/adapters/wecom/protocol_test.go index a002e462..c14528bc 100644 --- a/internal/channel/adapters/wecom/protocol_test.go +++ b/internal/channel/adapters/wecom/protocol_test.go @@ -23,7 +23,7 @@ func TestAuthCredentialsValidate(t *testing.T) { if err := (AuthCredentials{}).Validate(); err == nil { t.Fatal("expected validation error for empty credentials") } - if err := (AuthCredentials{BotID: "id", Secret: "sec"}).Validate(); err != nil { + if err := (AuthCredentials{BotID: "id", Credential: "sec"}).Validate(); err != nil { t.Fatalf("unexpected validation error: %v", err) } } diff --git a/internal/channel/adapters/wecom/wecom.go b/internal/channel/adapters/wecom/wecom.go index 81df25a7..a5a479bf 100644 --- a/internal/channel/adapters/wecom/wecom.go +++ b/internal/channel/adapters/wecom/wecom.go @@ -37,13 +37,13 @@ func NewWeComAdapter(log *slog.Logger) *WeComAdapter { clients: make(map[string]*WSClient), http: NewHTTPClient(HTTPClientOptions{Logger: log}), cache: newCallbackContextCache(24 * time.Hour), - newWSClient: func(opts WSClientOptions) *WSClient { return NewWSClient(opts) }, + newWSClient: NewWSClient, } } -func (a *WeComAdapter) Type() channel.ChannelType { return Type } +func (*WeComAdapter) Type() channel.ChannelType { return Type } -func (a *WeComAdapter) Descriptor() channel.Descriptor { +func (*WeComAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: Type, DisplayName: "WeCom", @@ -79,36 +79,36 @@ func (a *WeComAdapter) Descriptor() channel.Descriptor { TargetSpec: channel.TargetSpec{ Format: "chat_id:xxx | user_id:xxx", Hints: []channel.TargetHint{ - {Label: "Chat ID", Example: "chat_id:wrk_abc"}, + {Label: "Chat ID", Example: "chat_id:work_abc"}, {Label: "User ID", Example: "user_id:zhangsan"}, }, }, } } -func (a *WeComAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (*WeComAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) } -func (a *WeComAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func (*WeComAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return normalizeUserConfig(raw) } -func (a *WeComAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } +func (*WeComAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } -func (a *WeComAdapter) ResolveTarget(userConfig map[string]any) (string, error) { +func (*WeComAdapter) ResolveTarget(userConfig map[string]any) (string, error) { return resolveTarget(userConfig) } -func (a *WeComAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { +func (*WeComAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { return matchBinding(config, criteria) } -func (a *WeComAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (*WeComAdapter) BuildUserConfig(identity channel.Identity) map[string]any { return buildUserConfig(identity) } -func (a *WeComAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { +func (*WeComAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { _ = ctx cfg, err := parseConfig(credentials) if err != nil { @@ -148,8 +148,8 @@ func (a *WeComAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, h go func() { defer close(done) err := client.Run(connCtx, AuthCredentials{ - BotID: parsed.BotID, - Secret: parsed.Secret, + BotID: parsed.BotID, + Credential: parsed.Credential, }, func(frameCtx context.Context, frame WSFrame) error { return a.handleFrame(frameCtx, cfg, frame, handler) }) @@ -178,7 +178,7 @@ func (a *WeComAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, h func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { targetKind, targetID, ok := parseTarget(msg.Target) if !ok { - return fmt.Errorf("wecom target is required") + return errors.New("wecom target is required") } parsed, err := parseConfig(cfg.Credentials) if err != nil { @@ -186,10 +186,10 @@ func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg } client := a.getClient(parsed.BotID) if client == nil { - return fmt.Errorf("wecom connection is not active") + return errors.New("wecom connection is not active") } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } var ( payload any @@ -217,9 +217,14 @@ func (a *WeComAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg } func (a *WeComAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("wecom target is required") + return nil, errors.New("wecom target is required") } reply := opts.Reply if reply == nil && strings.TrimSpace(opts.SourceMessageID) != "" { @@ -410,7 +415,7 @@ func (a *WeComAdapter) sendRespondStream(ctx context.Context, cfg channel.Channe } client := a.getClient(parsed.BotID) if client == nil { - return fmt.Errorf("wecom connection is not active") + return errors.New("wecom connection is not active") } payload, cmd, ackReqID, err := buildRespondPayloadWithStream(msg, reqID, streamID, finish) if err != nil { diff --git a/internal/channel/adapters/wecom/wecom_test.go b/internal/channel/adapters/wecom/wecom_test.go index e58bfc1f..73e3e5b4 100644 --- a/internal/channel/adapters/wecom/wecom_test.go +++ b/internal/channel/adapters/wecom/wecom_test.go @@ -49,4 +49,3 @@ func TestOpenStream_FallbackReplyFromSourceMessageID(t *testing.T) { t.Fatalf("unexpected reply fallback: %+v", ws.reply) } } - diff --git a/internal/channel/adapters/wecom/ws_client.go b/internal/channel/adapters/wecom/ws_client.go index 04edcf42..3ea29b41 100644 --- a/internal/channel/adapters/wecom/ws_client.go +++ b/internal/channel/adapters/wecom/ws_client.go @@ -3,6 +3,7 @@ package wecom import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "net/http" @@ -91,7 +92,7 @@ func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(c } if c.opts.MaxReconnectAttempts >= 0 && attempt >= c.opts.MaxReconnectAttempts { if err == nil { - return fmt.Errorf("wecom websocket reconnect attempts exceeded") + return errors.New("wecom websocket reconnect attempts exceeded") } return err } @@ -113,16 +114,22 @@ func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(c } func (c *WSClient) runSession(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error { - conn, _, err := c.dial(ctx) + conn, resp, err := c.dial(ctx) if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } return err } + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } c.setConn(conn) sessionCtx, cancel := context.WithCancel(ctx) defer func() { cancel() c.clearConn() - c.failAllWaiters(fmt.Errorf("wecom websocket disconnected")) + c.failAllWaiters(errors.New("wecom websocket disconnected")) }() readErrCh := make(chan error, 1) @@ -155,9 +162,9 @@ func (c *WSClient) dial(ctx context.Context) (*websocket.Conn, *http.Response, e } func (c *WSClient) authenticate(ctx context.Context, auth AuthCredentials) error { - frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), SubscribeBody{ - BotID: strings.TrimSpace(auth.BotID), - Secret: strings.TrimSpace(auth.Secret), + frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), map[string]string{ + "bot_id": strings.TrimSpace(auth.BotID), + "secret": strings.TrimSpace(auth.Credential), }) if err != nil { return err @@ -200,7 +207,7 @@ func (c *WSClient) heartbeatLoop(ctx context.Context) { func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, WSFrame) error, errCh chan<- error) { conn := c.getConn() if conn == nil { - errCh <- fmt.Errorf("wecom websocket connection not ready") + errCh <- errors.New("wecom websocket connection not ready") return } for { @@ -230,12 +237,17 @@ func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, W } func (c *WSClient) Send(ctx context.Context, frame WSFrame) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } if strings.TrimSpace(frame.Headers.ReqID) == "" { - return fmt.Errorf("req_id is required") + return errors.New("req_id is required") } conn := c.getConn() if conn == nil { - return fmt.Errorf("wecom websocket is not connected") + return errors.New("wecom websocket is not connected") } c.writeMu.Lock() defer c.writeMu.Unlock() @@ -251,7 +263,7 @@ func (c *WSClient) Send(ctx context.Context, frame WSFrame) error { func (c *WSClient) SendWithAck(ctx context.Context, frame WSFrame) (WSFrame, error) { reqID := strings.TrimSpace(frame.Headers.ReqID) if reqID == "" { - return WSFrame{}, fmt.Errorf("req_id is required") + return WSFrame{}, errors.New("req_id is required") } wait := make(chan wsAck, 1) c.waitMu.Lock() @@ -293,7 +305,7 @@ func (c *WSClient) Close() error { conn := c.conn c.conn = nil c.connMu.Unlock() - c.failAllWaiters(fmt.Errorf("wecom websocket client closed")) + c.failAllWaiters(errors.New("wecom websocket client closed")) if conn == nil { return nil } diff --git a/internal/channel/adapters/wecom/ws_client_test.go b/internal/channel/adapters/wecom/ws_client_test.go index 7100dc41..3fd91618 100644 --- a/internal/channel/adapters/wecom/ws_client_test.go +++ b/internal/channel/adapters/wecom/ws_client_test.go @@ -3,6 +3,7 @@ package wecom import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -25,7 +26,7 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) { if err != nil { return } - defer conn.Close() + defer func() { _ = conn.Close() }() n := connCount.Add(1) var subscribeFrame WSFrame @@ -64,11 +65,11 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) { wsURL := "ws" + strings.TrimPrefix(server.URL, "http") client := NewWSClient(WSClientOptions{ - URL: wsURL, - AckTimeout: 200 * time.Millisecond, - HeartbeatInterval: 10 * time.Second, - ReconnectBaseDelay: 10 * time.Millisecond, - ReconnectMaxDelay: 20 * time.Millisecond, + URL: wsURL, + AckTimeout: 200 * time.Millisecond, + HeartbeatInterval: 10 * time.Second, + ReconnectBaseDelay: 10 * time.Millisecond, + ReconnectMaxDelay: 20 * time.Millisecond, MaxReconnectAttempts: 5, }) @@ -76,7 +77,7 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) { defer cancel() runErrCh := make(chan error, 1) go func() { - runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, func(context.Context, WSFrame) error { + runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Credential: "sec"}, func(context.Context, WSFrame) error { cancel() return nil }) @@ -89,7 +90,7 @@ func TestWSClientRun_ReconnectsAfterDisconnect(t *testing.T) { } select { case err := <-runErrCh: - if err == nil || err != context.Canceled { + if err == nil || !errors.Is(err, context.Canceled) { t.Fatalf("unexpected run error: %v", err) } case <-time.After(2 * time.Second): @@ -112,7 +113,7 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) { if err != nil { return } - defer conn.Close() + defer func() { _ = conn.Close() }() n := connCount.Add(1) var subscribeFrame WSFrame @@ -153,7 +154,7 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) { defer cancel() runErrCh := make(chan error, 1) go func() { - runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Secret: "sec"}, nil) + runErrCh <- client.Run(ctx, AuthCredentials{BotID: "bot", Credential: "sec"}, nil) }() select { @@ -165,7 +166,7 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) { cancel() select { case err := <-runErrCh: - if err == nil || err != context.Canceled { + if err == nil || !errors.Is(err, context.Canceled) { t.Fatalf("unexpected run error: %v", err) } case <-time.After(2 * time.Second): @@ -175,4 +176,3 @@ func TestWSClientRun_HeartbeatDoesNotRequireAck(t *testing.T) { t.Fatalf("expected at least one session, got %d", connCount.Load()) } } -