Files
Memoh/internal/channel/adapters/wecom/ws_client.go
T
BBQ 599bfb5ca8 fix(wecom): pass lint and typo checks
Fix WeCom adapter typos and strict Go lint findings (gosec/bodyclose/errcheck/revive) while keeping runtime behavior unchanged.
2026-03-11 02:14:00 +08:00

409 lines
9.0 KiB
Go

package wecom
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// WSClientOptions configures a WSClient. NewWSClient substitutes defaults for
// an empty URL, a nil Logger and every non-positive duration.
type WSClientOptions struct {
	// URL is the websocket endpoint to dial.
	URL string
	// Dialer performs the websocket handshake; websocket.DefaultDialer is
	// used when it is nil.
	Dialer *websocket.Dialer
	// Logger receives the client's log records; slog.Default() when nil.
	Logger *slog.Logger

	// AckTimeout bounds the wait for a reply in SendWithAck (default 8s).
	// WriteTimeout is the deadline applied to each frame write (default 8s).
	// ReadTimeout is the deadline applied to each frame read (default 70s).
	AckTimeout, WriteTimeout, ReadTimeout time.Duration
	// HeartbeatInterval is the period between heartbeat frames (default 30s).
	HeartbeatInterval time.Duration
	// ReconnectBaseDelay is the first reconnect delay (default 1s); it doubles
	// per attempt and is capped at ReconnectMaxDelay (default 30s).
	ReconnectBaseDelay, ReconnectMaxDelay time.Duration
	// MaxReconnectAttempts limits how many reconnects Run performs. Zero means
	// Run returns after the first session ends; a negative value removes the
	// limit.
	MaxReconnectAttempts int
}
// WSClient is a websocket client for the WeCom endpoint that authenticates,
// sends heartbeats, matches replies to pending requests by request ID and
// reconnects when a session ends.
type WSClient struct {
	opts   WSClientOptions
	logger *slog.Logger

	// writeMu serializes frame writes on the active connection.
	writeMu sync.Mutex

	// connMu guards conn and closed.
	connMu sync.RWMutex
	conn   *websocket.Conn
	closed bool

	// waitMu guards waiters, which maps a request ID to the channel of the
	// SendWithAck call waiting for the matching reply.
	waitMu  sync.Mutex
	waiters map[string]chan wsAck
}
// wsAck is the value delivered to a SendWithAck caller: either the frame that
// answered its request ID, or the error that ended the wait.
type wsAck struct {
	// frame is the reply frame; set when err is nil.
	frame WSFrame
	// err is the failure cause (disconnect or client close).
	err error
}
// NewWSClient returns a WSClient configured from opts. A blank URL is replaced
// with defaultWSURL, a nil Logger with slog.Default(), and every non-positive
// duration with its default. Dialer and MaxReconnectAttempts are kept as given.
func NewWSClient(opts WSClientOptions) *WSClient {
	// positiveOr keeps value when it is a usable duration, else the fallback.
	positiveOr := func(value, fallback time.Duration) time.Duration {
		if value > 0 {
			return value
		}
		return fallback
	}

	if strings.TrimSpace(opts.URL) == "" {
		opts.URL = defaultWSURL
	}
	if opts.Logger == nil {
		opts.Logger = slog.Default()
	}
	opts.AckTimeout = positiveOr(opts.AckTimeout, 8*time.Second)
	opts.WriteTimeout = positiveOr(opts.WriteTimeout, 8*time.Second)
	opts.ReadTimeout = positiveOr(opts.ReadTimeout, 70*time.Second)
	opts.HeartbeatInterval = positiveOr(opts.HeartbeatInterval, 30*time.Second)
	opts.ReconnectBaseDelay = positiveOr(opts.ReconnectBaseDelay, 1*time.Second)
	opts.ReconnectMaxDelay = positiveOr(opts.ReconnectMaxDelay, 30*time.Second)

	client := &WSClient{
		opts:    opts,
		waiters: make(map[string]chan wsAck),
	}
	client.logger = opts.Logger.With(slog.String("component", "wecom_ws_client"))
	return client
}
// Run validates auth and then runs websocket sessions back to back until it
// has to stop, blocking the caller for the whole time.
//
// It returns ctx.Err() once ctx is done, nil once Close has been called, and
// otherwise the last session error (or a generic "attempts exceeded" error)
// when MaxReconnectAttempts is non-negative and that many reconnects have
// already been made. The reconnect counter is cumulative: it is not reset
// when a session succeeds. Between sessions Run sleeps for the backoff delay.
func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
	if err := auth.Validate(); err != nil {
		return err
	}

	limit := c.opts.MaxReconnectAttempts
	for retries := 0; ; retries++ {
		sessionErr := c.runSession(ctx, auth, onFrame)

		switch {
		case ctx.Err() != nil:
			return ctx.Err()
		case c.isClosed():
			return nil
		case limit >= 0 && retries >= limit:
			if sessionErr != nil {
				return sessionErr
			}
			return errors.New("wecom websocket reconnect attempts exceeded")
		}

		wait := c.backoff(retries)
		c.logger.Warn("wecom websocket session ended; reconnecting",
			slog.Int("attempt", retries+1),
			slog.Duration("delay", wait),
			slog.Any("error", sessionErr),
		)

		// Sleep for the backoff delay, but wake immediately on cancellation.
		pause := time.NewTimer(wait)
		select {
		case <-pause.C:
		case <-ctx.Done():
			pause.Stop()
			return ctx.Err()
		}
	}
}
// runSession dials one websocket connection, authenticates on it and blocks
// until the session ends, returning the reason.
//
// The read goroutine is started before authenticating so the subscribe reply
// can reach SendWithAck; the heartbeat goroutine starts only after
// authentication succeeds. On return the session context is cancelled, the
// connection is cleared and closed, and every pending SendWithAck caller is
// failed with a "disconnected" error.
func (c *WSClient) runSession(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
	conn, resp, err := c.dial(ctx)
	// The handshake response body is never read; close it whether or not the
	// dial succeeded.
	if resp != nil && resp.Body != nil {
		_ = resp.Body.Close()
	}
	if err != nil {
		return err
	}

	c.setConn(conn)
	sessionCtx, cancel := context.WithCancel(ctx)
	defer func() {
		cancel()
		c.clearConn()
		c.failAllWaiters(errors.New("wecom websocket disconnected"))
	}()

	// Close may have run while the dial was in flight. It then found no
	// connection to close, so without this check the freshly installed
	// connection would keep the session alive on a closed client. Checking
	// after setConn covers every interleaving: either Close saw the new
	// connection and closed it, or we observe closed here and the deferred
	// clearConn closes it.
	if c.isClosed() {
		return errors.New("wecom websocket client closed")
	}

	// Buffered so readLoop can report its error and exit even when this
	// function has already returned.
	readErrCh := make(chan error, 1)
	go c.readLoop(sessionCtx, onFrame, readErrCh)

	if err := c.authenticate(sessionCtx, auth); err != nil {
		_ = conn.Close()
		return err
	}
	c.logger.Info("wecom websocket authenticated")
	go c.heartbeatLoop(sessionCtx)

	select {
	case <-sessionCtx.Done():
		// Closing the connection unblocks the pending read in readLoop.
		_ = conn.Close()
		return sessionCtx.Err()
	case err := <-readErrCh:
		_ = conn.Close()
		return err
	}
}
// dial opens a websocket connection to the configured URL with no extra
// request headers, using the configured Dialer or websocket.DefaultDialer
// when none was supplied.
func (c *WSClient) dial(ctx context.Context) (*websocket.Conn, *http.Response, error) {
	if custom := c.opts.Dialer; custom != nil {
		return custom.DialContext(ctx, c.opts.URL, nil)
	}
	return websocket.DefaultDialer.DialContext(ctx, c.opts.URL, nil)
}
// authenticate sends a subscribe frame carrying the trimmed bot ID and secret
// and waits for the matching reply. It returns an error when the frame cannot
// be built or sent, when no reply arrives, or when the reply carries a
// non-zero error code.
func (c *WSClient) authenticate(ctx context.Context, auth AuthCredentials) error {
	credentials := map[string]string{
		"bot_id": strings.TrimSpace(auth.BotID),
		"secret": strings.TrimSpace(auth.Credential),
	}
	subscribe, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), credentials)
	if err != nil {
		return err
	}

	reply, err := c.SendWithAck(ctx, subscribe)
	if err != nil {
		return err
	}
	if reply.ErrCode == 0 {
		return nil
	}
	return fmt.Errorf("wecom subscribe failed: %s (code: %d)", strings.TrimSpace(reply.ErrMsg), reply.ErrCode)
}
// heartbeatLoop sends a heartbeat frame every HeartbeatInterval until ctx is
// done. A frame that cannot be built is logged and skipped. A frame that
// cannot be sent is logged, the current connection is closed, and the loop
// exits.
func (c *WSClient) heartbeatLoop(ctx context.Context) {
	ticker := time.NewTicker(c.opts.HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		ping, buildErr := BuildFrame(WSCmdHeartbeat, NewReqID(WSCmdHeartbeat), nil)
		if buildErr != nil {
			c.logger.Error("build heartbeat frame failed", slog.Any("error", buildErr))
			continue
		}

		sendErr := c.Send(ctx, ping)
		if sendErr == nil {
			continue
		}
		c.logger.Warn("wecom websocket heartbeat failed", slog.Any("error", sendErr))
		if active := c.getConn(); active != nil {
			_ = active.Close()
		}
		return
	}
}
// readLoop reads frames from the current connection until a read fails, then
// reports that error on errCh and returns. Undecodable payloads are logged and
// skipped. A frame whose request ID matches a pending SendWithAck call is
// handed to that caller; any other frame goes to onFrame when it is non-nil.
// onFrame runs on the reading goroutine, so no further frame is read until it
// returns; its error is only logged.
func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, WSFrame) error, errCh chan<- error) {
	conn := c.getConn()
	if conn == nil {
		errCh <- errors.New("wecom websocket connection not ready")
		return
	}

	readTimeout := c.opts.ReadTimeout
	for {
		// Refresh the deadline before every read so an idle link is detected.
		if readTimeout > 0 {
			_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
		}
		_, raw, readErr := conn.ReadMessage()
		if readErr != nil {
			errCh <- readErr
			return
		}

		var incoming WSFrame
		if decodeErr := json.Unmarshal(raw, &incoming); decodeErr != nil {
			c.logger.Warn("decode websocket frame failed", slog.Any("error", decodeErr))
			continue
		}
		if c.dispatchAck(incoming) || onFrame == nil {
			continue
		}
		if cbErr := onFrame(ctx, incoming); cbErr != nil {
			c.logger.Warn("wecom onFrame callback returned error", slog.Any("error", cbErr))
		}
	}
}
// Send writes frame to the current connection as JSON without waiting for a
// reply. It fails when ctx is already done, when the frame has a blank
// request ID, when there is no connection, or when the write itself fails.
// Writes are serialized by writeMu and bounded by WriteTimeout.
func (c *WSClient) Send(ctx context.Context, frame WSFrame) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	if strings.TrimSpace(frame.Headers.ReqID) == "" {
		return errors.New("req_id is required")
	}
	conn := c.getConn()
	if conn == nil {
		return errors.New("wecom websocket is not connected")
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if limit := c.opts.WriteTimeout; limit > 0 {
		_ = conn.SetWriteDeadline(time.Now().Add(limit))
	}
	return conn.WriteJSON(frame)
}
// SendWithAck sends frame and waits for the frame that answers its request ID.
//
// The waiter is registered under the trimmed request ID before the frame is
// written, so a reply cannot arrive unobserved, and is removed again on
// return. The wait ends with the reply frame, with the error that failed the
// waiter (disconnect or client close), with ctx.Err(), or with a timeout
// error after AckTimeout, shortened to the remaining ctx deadline when that
// is sooner.
func (c *WSClient) SendWithAck(ctx context.Context, frame WSFrame) (WSFrame, error) {
	reqID := strings.TrimSpace(frame.Headers.ReqID)
	if reqID == "" {
		return WSFrame{}, errors.New("req_id is required")
	}

	pending := make(chan wsAck, 1)
	c.waitMu.Lock()
	c.waiters[reqID] = pending
	c.waitMu.Unlock()
	defer func() {
		c.waitMu.Lock()
		delete(c.waiters, reqID)
		c.waitMu.Unlock()
	}()

	if err := c.Send(ctx, frame); err != nil {
		return WSFrame{}, err
	}

	ackWait := c.opts.AckTimeout
	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		if left := time.Until(deadline); left > 0 && left < ackWait {
			ackWait = left
		}
	}
	expiry := time.NewTimer(ackWait)
	defer expiry.Stop()

	select {
	case result := <-pending:
		if result.err != nil {
			return WSFrame{}, result.err
		}
		return result.frame, nil
	case <-expiry.C:
		return WSFrame{}, fmt.Errorf("wait websocket ack timeout for req_id=%s", reqID)
	case <-ctx.Done():
		return WSFrame{}, ctx.Err()
	}
}
// Close marks the client as closed, detaches the current connection, fails
// every pending SendWithAck caller with a "client closed" error and closes the
// detached connection. It returns that close error, or nil when there was no
// connection.
func (c *WSClient) Close() error {
	c.connMu.Lock()
	active := c.conn
	c.conn, c.closed = nil, true
	c.connMu.Unlock()

	c.failAllWaiters(errors.New("wecom websocket client closed"))

	if active != nil {
		return active.Close()
	}
	return nil
}
// Reply builds a frame for cmd that reuses reqID and sends it. Respond-type
// commands are sent without waiting and yield an empty frame; every other
// command goes through SendWithAck and yields the reply frame.
func (c *WSClient) Reply(ctx context.Context, reqID string, cmd string, body any) (WSFrame, error) {
	frame, err := BuildFrame(cmd, reqID, body)
	if err != nil {
		return WSFrame{}, err
	}
	if !isRespondCommand(cmd) {
		return c.SendWithAck(ctx, frame)
	}
	// Callback replies reuse the inbound callback's req_id, and in production
	// WeCom does not always answer them with an explicit ACK frame. Waiting
	// for one would report a timeout even though the platform accepted the
	// reply, so these are sent fire-and-forget.
	return WSFrame{}, c.Send(ctx, frame)
}
// isRespondCommand reports whether cmd, ignoring surrounding whitespace, is
// one of the respond commands (respond, respond-welcome, respond-update).
func isRespondCommand(cmd string) bool {
	trimmed := strings.TrimSpace(cmd)
	return trimmed == WSCmdRespond ||
		trimmed == WSCmdRespondWelcome ||
		trimmed == WSCmdRespondUpdate
}
// dispatchAck offers frame to the SendWithAck caller waiting on its trimmed
// request ID and reports whether such a caller was registered. The hand-off
// never blocks: when the waiter's channel is already full the frame is
// dropped, and true is still returned.
func (c *WSClient) dispatchAck(frame WSFrame) bool {
	key := strings.TrimSpace(frame.Headers.ReqID)
	if key == "" {
		return false
	}

	c.waitMu.Lock()
	pending, registered := c.waiters[key]
	c.waitMu.Unlock()

	if registered {
		select {
		case pending <- wsAck{frame: frame}:
		default:
		}
	}
	return registered
}
// failAllWaiters delivers cause to every pending SendWithAck caller and
// empties the waiter table. Delivery never blocks: a waiter whose channel is
// already full is skipped.
func (c *WSClient) failAllWaiters(cause error) {
	failure := wsAck{err: cause}

	c.waitMu.Lock()
	defer c.waitMu.Unlock()
	for id, pending := range c.waiters {
		select {
		case pending <- failure:
		default:
		}
		delete(c.waiters, id)
	}
}
// backoff returns the delay to wait before reconnect number attempt
// (zero-based): ReconnectBaseDelay doubled attempt times, capped at
// ReconnectMaxDelay. A negative attempt is treated as zero.
//
// The doubling saturates instead of shifting, because a plain
// `base << attempt` overflows int64 for large attempts (with a 1s base it
// turns negative at attempt 34 and becomes zero from attempt 55). That would
// slip past the cap and make the reconnect loop spin without any delay.
func (c *WSClient) backoff(attempt int) time.Duration {
	delay, ceiling := c.opts.ReconnectBaseDelay, c.opts.ReconnectMaxDelay
	if delay > ceiling {
		return ceiling
	}
	// Invariant inside the loop: 0 < delay <= ceiling, so ceiling-delay cannot
	// overflow and the loop returns after at most 63 doublings. A non-positive
	// base (not produced by NewWSClient) is returned unchanged.
	for ; attempt > 0 && delay > 0; attempt-- {
		if delay > ceiling-delay {
			return ceiling
		}
		delay *= 2
	}
	return delay
}
// setConn installs conn as the client's current connection.
func (c *WSClient) setConn(conn *websocket.Conn) {
	c.connMu.Lock()
	c.conn = conn
	c.connMu.Unlock()
}
// clearConn detaches the current connection, if any, and closes it. The close
// happens after connMu is released and its error is ignored.
func (c *WSClient) clearConn() {
	c.connMu.Lock()
	stale := c.conn
	c.conn = nil
	c.connMu.Unlock()

	if stale == nil {
		return
	}
	_ = stale.Close()
}
// getConn returns the current connection, or nil when there is none.
func (c *WSClient) getConn() *websocket.Conn {
	c.connMu.RLock()
	current := c.conn
	c.connMu.RUnlock()
	return current
}
// isClosed reports whether Close has been called on the client.
func (c *WSClient) isClosed() bool {
	c.connMu.RLock()
	closed := c.closed
	c.connMu.RUnlock()
	return closed
}