mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
599bfb5ca8
Fix WeCom adapter typos and strict Go lint findings (gosec/bodyclose/errcheck/revive) while keeping runtime behavior unchanged.
409 lines
9.0 KiB
Go
409 lines
9.0 KiB
Go
package wecom
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
type WSClientOptions struct {
|
|
URL string
|
|
Dialer *websocket.Dialer
|
|
Logger *slog.Logger
|
|
AckTimeout time.Duration
|
|
WriteTimeout time.Duration
|
|
ReadTimeout time.Duration
|
|
HeartbeatInterval time.Duration
|
|
ReconnectBaseDelay time.Duration
|
|
ReconnectMaxDelay time.Duration
|
|
MaxReconnectAttempts int
|
|
}
|
|
|
|
type WSClient struct {
|
|
opts WSClientOptions
|
|
logger *slog.Logger
|
|
|
|
writeMu sync.Mutex
|
|
waitMu sync.Mutex
|
|
connMu sync.RWMutex
|
|
|
|
conn *websocket.Conn
|
|
waiters map[string]chan wsAck
|
|
closed bool
|
|
}
|
|
|
|
type wsAck struct {
|
|
frame WSFrame
|
|
err error
|
|
}
|
|
|
|
func NewWSClient(opts WSClientOptions) *WSClient {
|
|
if strings.TrimSpace(opts.URL) == "" {
|
|
opts.URL = defaultWSURL
|
|
}
|
|
if opts.Logger == nil {
|
|
opts.Logger = slog.Default()
|
|
}
|
|
if opts.AckTimeout <= 0 {
|
|
opts.AckTimeout = 8 * time.Second
|
|
}
|
|
if opts.WriteTimeout <= 0 {
|
|
opts.WriteTimeout = 8 * time.Second
|
|
}
|
|
if opts.ReadTimeout <= 0 {
|
|
opts.ReadTimeout = 70 * time.Second
|
|
}
|
|
if opts.HeartbeatInterval <= 0 {
|
|
opts.HeartbeatInterval = 30 * time.Second
|
|
}
|
|
if opts.ReconnectBaseDelay <= 0 {
|
|
opts.ReconnectBaseDelay = 1 * time.Second
|
|
}
|
|
if opts.ReconnectMaxDelay <= 0 {
|
|
opts.ReconnectMaxDelay = 30 * time.Second
|
|
}
|
|
return &WSClient{
|
|
opts: opts,
|
|
logger: opts.Logger.With(slog.String("component", "wecom_ws_client")),
|
|
waiters: make(map[string]chan wsAck),
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) Run(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
|
|
if err := auth.Validate(); err != nil {
|
|
return err
|
|
}
|
|
attempt := 0
|
|
for {
|
|
err := c.runSession(ctx, auth, onFrame)
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
if c.isClosed() {
|
|
return nil
|
|
}
|
|
if c.opts.MaxReconnectAttempts >= 0 && attempt >= c.opts.MaxReconnectAttempts {
|
|
if err == nil {
|
|
return errors.New("wecom websocket reconnect attempts exceeded")
|
|
}
|
|
return err
|
|
}
|
|
delay := c.backoff(attempt)
|
|
attempt++
|
|
c.logger.Warn("wecom websocket session ended; reconnecting",
|
|
slog.Int("attempt", attempt),
|
|
slog.Duration("delay", delay),
|
|
slog.Any("error", err),
|
|
)
|
|
timer := time.NewTimer(delay)
|
|
select {
|
|
case <-ctx.Done():
|
|
timer.Stop()
|
|
return ctx.Err()
|
|
case <-timer.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) runSession(ctx context.Context, auth AuthCredentials, onFrame func(context.Context, WSFrame) error) error {
|
|
conn, resp, err := c.dial(ctx)
|
|
if err != nil {
|
|
if resp != nil && resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
return err
|
|
}
|
|
if resp != nil && resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
c.setConn(conn)
|
|
sessionCtx, cancel := context.WithCancel(ctx)
|
|
defer func() {
|
|
cancel()
|
|
c.clearConn()
|
|
c.failAllWaiters(errors.New("wecom websocket disconnected"))
|
|
}()
|
|
|
|
readErrCh := make(chan error, 1)
|
|
go c.readLoop(sessionCtx, onFrame, readErrCh)
|
|
|
|
if err := c.authenticate(sessionCtx, auth); err != nil {
|
|
_ = conn.Close()
|
|
return err
|
|
}
|
|
c.logger.Info("wecom websocket authenticated")
|
|
|
|
go c.heartbeatLoop(sessionCtx)
|
|
|
|
select {
|
|
case <-sessionCtx.Done():
|
|
_ = conn.Close()
|
|
return sessionCtx.Err()
|
|
case err := <-readErrCh:
|
|
_ = conn.Close()
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) dial(ctx context.Context) (*websocket.Conn, *http.Response, error) {
|
|
dialer := c.opts.Dialer
|
|
if dialer == nil {
|
|
dialer = websocket.DefaultDialer
|
|
}
|
|
return dialer.DialContext(ctx, c.opts.URL, nil)
|
|
}
|
|
|
|
func (c *WSClient) authenticate(ctx context.Context, auth AuthCredentials) error {
|
|
frame, err := BuildFrame(WSCmdSubscribe, NewReqID(WSCmdSubscribe), map[string]string{
|
|
"bot_id": strings.TrimSpace(auth.BotID),
|
|
"secret": strings.TrimSpace(auth.Credential),
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ack, err := c.SendWithAck(ctx, frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if ack.ErrCode != 0 {
|
|
return fmt.Errorf("wecom subscribe failed: %s (code: %d)", strings.TrimSpace(ack.ErrMsg), ack.ErrCode)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *WSClient) heartbeatLoop(ctx context.Context) {
|
|
ticker := time.NewTicker(c.opts.HeartbeatInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
frame, err := BuildFrame(WSCmdHeartbeat, NewReqID(WSCmdHeartbeat), nil)
|
|
if err != nil {
|
|
c.logger.Error("build heartbeat frame failed", slog.Any("error", err))
|
|
continue
|
|
}
|
|
err = c.Send(ctx, frame)
|
|
if err != nil {
|
|
c.logger.Warn("wecom websocket heartbeat failed", slog.Any("error", err))
|
|
if conn := c.getConn(); conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) readLoop(ctx context.Context, onFrame func(context.Context, WSFrame) error, errCh chan<- error) {
|
|
conn := c.getConn()
|
|
if conn == nil {
|
|
errCh <- errors.New("wecom websocket connection not ready")
|
|
return
|
|
}
|
|
for {
|
|
if c.opts.ReadTimeout > 0 {
|
|
_ = conn.SetReadDeadline(time.Now().Add(c.opts.ReadTimeout))
|
|
}
|
|
_, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
var frame WSFrame
|
|
if err := json.Unmarshal(payload, &frame); err != nil {
|
|
c.logger.Warn("decode websocket frame failed", slog.Any("error", err))
|
|
continue
|
|
}
|
|
if c.dispatchAck(frame) {
|
|
continue
|
|
}
|
|
if onFrame == nil {
|
|
continue
|
|
}
|
|
if err := onFrame(ctx, frame); err != nil {
|
|
c.logger.Warn("wecom onFrame callback returned error", slog.Any("error", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) Send(ctx context.Context, frame WSFrame) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
if strings.TrimSpace(frame.Headers.ReqID) == "" {
|
|
return errors.New("req_id is required")
|
|
}
|
|
conn := c.getConn()
|
|
if conn == nil {
|
|
return errors.New("wecom websocket is not connected")
|
|
}
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
if c.opts.WriteTimeout > 0 {
|
|
_ = conn.SetWriteDeadline(time.Now().Add(c.opts.WriteTimeout))
|
|
}
|
|
if err := conn.WriteJSON(frame); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *WSClient) SendWithAck(ctx context.Context, frame WSFrame) (WSFrame, error) {
|
|
reqID := strings.TrimSpace(frame.Headers.ReqID)
|
|
if reqID == "" {
|
|
return WSFrame{}, errors.New("req_id is required")
|
|
}
|
|
wait := make(chan wsAck, 1)
|
|
c.waitMu.Lock()
|
|
c.waiters[reqID] = wait
|
|
c.waitMu.Unlock()
|
|
defer func() {
|
|
c.waitMu.Lock()
|
|
delete(c.waiters, reqID)
|
|
c.waitMu.Unlock()
|
|
}()
|
|
if err := c.Send(ctx, frame); err != nil {
|
|
return WSFrame{}, err
|
|
}
|
|
timeout := c.opts.AckTimeout
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
remaining := time.Until(deadline)
|
|
if remaining > 0 && remaining < timeout {
|
|
timeout = remaining
|
|
}
|
|
}
|
|
timer := time.NewTimer(timeout)
|
|
defer timer.Stop()
|
|
select {
|
|
case <-ctx.Done():
|
|
return WSFrame{}, ctx.Err()
|
|
case <-timer.C:
|
|
return WSFrame{}, fmt.Errorf("wait websocket ack timeout for req_id=%s", reqID)
|
|
case ack := <-wait:
|
|
if ack.err != nil {
|
|
return WSFrame{}, ack.err
|
|
}
|
|
return ack.frame, nil
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) Close() error {
|
|
c.connMu.Lock()
|
|
c.closed = true
|
|
conn := c.conn
|
|
c.conn = nil
|
|
c.connMu.Unlock()
|
|
c.failAllWaiters(errors.New("wecom websocket client closed"))
|
|
if conn == nil {
|
|
return nil
|
|
}
|
|
return conn.Close()
|
|
}
|
|
|
|
func (c *WSClient) Reply(ctx context.Context, reqID string, cmd string, body any) (WSFrame, error) {
|
|
frame, err := BuildFrame(cmd, reqID, body)
|
|
if err != nil {
|
|
return WSFrame{}, err
|
|
}
|
|
// WeCom callback reply commands are triggered by inbound callback req_id and may
|
|
// not always return an explicit ACK frame in production. Waiting for ACK here can
|
|
// cause false timeouts even when the platform accepts the reply.
|
|
if isRespondCommand(cmd) {
|
|
if err := c.Send(ctx, frame); err != nil {
|
|
return WSFrame{}, err
|
|
}
|
|
return WSFrame{}, nil
|
|
}
|
|
return c.SendWithAck(ctx, frame)
|
|
}
|
|
|
|
func isRespondCommand(cmd string) bool {
|
|
switch strings.TrimSpace(cmd) {
|
|
case WSCmdRespond, WSCmdRespondWelcome, WSCmdRespondUpdate:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) dispatchAck(frame WSFrame) bool {
|
|
reqID := strings.TrimSpace(frame.Headers.ReqID)
|
|
if reqID == "" {
|
|
return false
|
|
}
|
|
c.waitMu.Lock()
|
|
wait, ok := c.waiters[reqID]
|
|
c.waitMu.Unlock()
|
|
if !ok {
|
|
return false
|
|
}
|
|
select {
|
|
case wait <- wsAck{frame: frame}:
|
|
default:
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (c *WSClient) failAllWaiters(cause error) {
|
|
c.waitMu.Lock()
|
|
defer c.waitMu.Unlock()
|
|
for id, wait := range c.waiters {
|
|
delete(c.waiters, id)
|
|
select {
|
|
case wait <- wsAck{err: cause}:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) backoff(attempt int) time.Duration {
|
|
if attempt < 0 {
|
|
attempt = 0
|
|
}
|
|
delay := c.opts.ReconnectBaseDelay << attempt
|
|
if delay > c.opts.ReconnectMaxDelay {
|
|
return c.opts.ReconnectMaxDelay
|
|
}
|
|
return delay
|
|
}
|
|
|
|
func (c *WSClient) setConn(conn *websocket.Conn) {
|
|
c.connMu.Lock()
|
|
defer c.connMu.Unlock()
|
|
c.conn = conn
|
|
}
|
|
|
|
func (c *WSClient) clearConn() {
|
|
c.connMu.Lock()
|
|
conn := c.conn
|
|
c.conn = nil
|
|
c.connMu.Unlock()
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
func (c *WSClient) getConn() *websocket.Conn {
|
|
c.connMu.RLock()
|
|
defer c.connMu.RUnlock()
|
|
return c.conn
|
|
}
|
|
|
|
func (c *WSClient) isClosed() bool {
|
|
c.connMu.RLock()
|
|
defer c.connMu.RUnlock()
|
|
return c.closed
|
|
}
|