mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
401 lines
10 KiB
Go
401 lines
10 KiB
Go
package qq
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/memohai/memoh/internal/channel"
|
|
)
|
|
|
|
// qqClient bundles the credentials, endpoints and cached state used by the
// QQ bot HTTP helpers in this file.
type qqClient struct {
	// Credentials posted to the token endpoint.
	appID, clientSecret string

	// httpClient executes every outgoing request.
	httpClient *http.Client
	// logger is a minimal sink for debug output.
	logger interface {
		Debug(string, ...any)
	}

	// apiBaseURL is prepended to API paths; tokenURL is the endpoint that
	// issues access tokens.
	apiBaseURL, tokenURL string

	// tokenMu guards the cached access token and its expiry.
	tokenMu   sync.Mutex
	token     string
	expiresAt time.Time

	// msgSeqMu guards msgSeq, which holds the last sequence number handed out
	// for each replied-to message ID.
	msgSeqMu sync.Mutex
	msgSeq   map[string]int
}
|
|
|
|
func (c *qqClient) matches(cfg Config) bool {
|
|
return c.appID == cfg.AppID && c.clientSecret == cfg.AppSecret
|
|
}
|
|
|
|
func (c *qqClient) clearToken() {
|
|
c.tokenMu.Lock()
|
|
defer c.tokenMu.Unlock()
|
|
c.token = ""
|
|
c.expiresAt = time.Time{}
|
|
}
|
|
|
|
func (c *qqClient) accessToken(ctx context.Context) (string, error) {
|
|
c.tokenMu.Lock()
|
|
defer c.tokenMu.Unlock()
|
|
|
|
if c.token != "" && time.Now().Before(c.expiresAt.Add(-5*time.Minute)) {
|
|
return c.token, nil
|
|
}
|
|
|
|
payload := map[string]string{
|
|
"appId": c.appID,
|
|
"clientSecret": c.clientSecret,
|
|
}
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
u, err := url.Parse(c.tokenURL)
|
|
if err != nil || (u.Scheme != "https" && !isLocalhost(u.Host)) {
|
|
return "", fmt.Errorf("invalid token url: %s", c.tokenURL)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.tokenURL, bytes.NewReader(body))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(req) //nolint:gosec // token URL is validated to https or localhost above
|
|
if err != nil {
|
|
return "", fmt.Errorf("qq token request: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
raw, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("qq token read: %w", err)
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return "", fmt.Errorf("qq token request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
|
}
|
|
|
|
var result map[string]json.RawMessage
|
|
if err := json.Unmarshal(raw, &result); err != nil {
|
|
return "", fmt.Errorf("qq token decode: %w", err)
|
|
}
|
|
tokenBytes, ok := result["access_token"]
|
|
if !ok {
|
|
return "", errors.New("qq token response missing access_token")
|
|
}
|
|
var token string
|
|
if err := json.Unmarshal(tokenBytes, &token); err != nil {
|
|
return "", fmt.Errorf("qq token decode: %w", err)
|
|
}
|
|
if strings.TrimSpace(token) == "" {
|
|
return "", errors.New("qq token response missing access_token")
|
|
}
|
|
expiresIn := parseQQExpiresIn(result["expires_in"])
|
|
if expiresIn <= 0 {
|
|
expiresIn = 7200
|
|
}
|
|
c.token = token
|
|
c.expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
|
channel.SetIMErrorSecrets("qq-token:"+c.appID, c.clientSecret, c.token)
|
|
return c.token, nil
|
|
}
|
|
|
|
func (c *qqClient) gatewayURL(ctx context.Context) (string, error) {
|
|
var result struct {
|
|
URL string `json:"url"`
|
|
}
|
|
if err := c.doJSON(ctx, http.MethodGet, "/gateway", nil, &result); err != nil {
|
|
return "", err
|
|
}
|
|
if strings.TrimSpace(result.URL) == "" {
|
|
return "", errors.New("qq gateway response missing url")
|
|
}
|
|
return result.URL, nil
|
|
}
|
|
|
|
func (c *qqClient) nextMsgSeq(replyTo string) int {
|
|
if strings.TrimSpace(replyTo) == "" {
|
|
return 1
|
|
}
|
|
c.msgSeqMu.Lock()
|
|
defer c.msgSeqMu.Unlock()
|
|
|
|
next := c.msgSeq[replyTo] + 1
|
|
c.msgSeq[replyTo] = next
|
|
if len(c.msgSeq) > 1024 {
|
|
for key := range c.msgSeq {
|
|
delete(c.msgSeq, key)
|
|
if len(c.msgSeq) <= 512 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return next
|
|
}
|
|
|
|
// parseQQExpiresIn extracts a lifetime in seconds from the raw JSON value of
// an "expires_in" field.
//
// Accepted forms are a JSON integer (7200), a JSON string holding an integer
// ("7200", surrounding spaces allowed) and, for tolerance, a number or numeric
// string with a fraction or exponent (7200.0, "7.2e3"), which is truncated
// toward zero. Absent, null, non-numeric, non-finite or out-of-range values
// yield 0 so the caller can apply its own default.
func parseQQExpiresIn(raw json.RawMessage) int {
	trimmed := strings.TrimSpace(string(raw))
	if trimmed == "" || trimmed == "null" {
		return 0
	}

	var numeric int
	if err := json.Unmarshal(raw, &numeric); err == nil {
		return numeric
	}

	// Either a JSON string wrapping the number, or a non-integer JSON number.
	candidate := trimmed
	var text string
	if err := json.Unmarshal(raw, &text); err == nil {
		candidate = strings.TrimSpace(text)
	}
	if value, err := strconv.Atoi(candidate); err == nil {
		return value
	}

	// ParseFloat also accepts "NaN" and "Inf"; the range check rejects both and
	// keeps the conversion to int well-defined.
	const maxSeconds = 1 << 31
	if f, err := strconv.ParseFloat(candidate, 64); err == nil && f > 0 && f < maxSeconds {
		return int(f)
	}
	return 0
}
|
|
|
|
func (c *qqClient) doJSON(ctx context.Context, method, path string, payload any, out any) error {
|
|
return c.doJSONWithRetry(ctx, method, c.apiBaseURL+path, payload, out, true)
|
|
}
|
|
|
|
func (c *qqClient) doJSONWithRetry(ctx context.Context, method, url string, payload any, out any, auth bool) error {
|
|
var lastErr error
|
|
for attempt := 0; attempt < 2; attempt++ {
|
|
lastErr = c.doJSONOnce(ctx, method, url, payload, out, auth)
|
|
if lastErr == nil {
|
|
return nil
|
|
}
|
|
if !auth || !strings.Contains(lastErr.Error(), "status=401") {
|
|
return lastErr
|
|
}
|
|
c.clearToken()
|
|
}
|
|
return lastErr
|
|
}
|
|
|
|
func (c *qqClient) doJSONOnce(ctx context.Context, method, requestURL string, payload any, out any, auth bool) error {
|
|
var body io.Reader
|
|
if payload != nil {
|
|
encoded, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
body = bytes.NewReader(encoded)
|
|
}
|
|
|
|
u, err := url.Parse(requestURL)
|
|
if err != nil || (u.Scheme != "https" && !isLocalhost(u.Host)) {
|
|
return fmt.Errorf("invalid api url: %s", requestURL)
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, method, requestURL, body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
if auth {
|
|
token, err := c.accessToken(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Authorization", "QQBot "+token)
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req) //nolint:gosec // requestURL is validated to https or localhost above
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
raw, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return fmt.Errorf(
|
|
"qq api request failed: method=%s url=%s status=%d body=%s",
|
|
method,
|
|
requestURL,
|
|
resp.StatusCode,
|
|
strings.TrimSpace(string(raw)),
|
|
)
|
|
}
|
|
if out == nil || len(raw) == 0 {
|
|
return nil
|
|
}
|
|
if err := json.Unmarshal(raw, out); err != nil {
|
|
return fmt.Errorf("qq api decode failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// qqMessageResponse is the decode target for message-send responses.
type qqMessageResponse struct {
	// ID is decoded from the "id" field.
	ID string `json:"id"`
	// Timestamp is decoded from "timestamp" and left untyped, so any JSON
	// value is accepted.
	Timestamp any `json:"timestamp"`
}
|
|
|
|
// qqUploadResponse is the decode target for media-upload responses.
type qqUploadResponse struct {
	// FileUUID is decoded from "file_uuid".
	FileUUID string `json:"file_uuid"`
	// FileInfo is decoded from "file_info"; uploadMedia returns it to the
	// caller for use in a later media message.
	FileInfo string `json:"file_info"`
	// TTL is decoded from "ttl".
	TTL int `json:"ttl"`
}
|
|
|
|
func (c *qqClient) sendText(ctx context.Context, target qqTarget, text string, replyTo string, markdown bool) error {
|
|
text = strings.TrimSpace(text)
|
|
if text == "" {
|
|
return nil
|
|
}
|
|
|
|
switch target.Kind {
|
|
case qqTargetC2C:
|
|
if replyTo == "" {
|
|
return c.sendProactive(ctx, "/v2/users/"+target.ID+"/messages", text, markdown)
|
|
}
|
|
body := buildReplyTextBody(text, replyTo, c.nextMsgSeq(replyTo), markdown)
|
|
return c.doJSON(ctx, http.MethodPost, "/v2/users/"+target.ID+"/messages", body, &qqMessageResponse{})
|
|
case qqTargetGroup:
|
|
if replyTo == "" {
|
|
return c.sendProactive(ctx, "/v2/groups/"+target.ID+"/messages", text, markdown)
|
|
}
|
|
body := buildReplyTextBody(text, replyTo, c.nextMsgSeq(replyTo), markdown)
|
|
return c.doJSON(ctx, http.MethodPost, "/v2/groups/"+target.ID+"/messages", body, &qqMessageResponse{})
|
|
case qqTargetChannel:
|
|
body := map[string]any{"content": text}
|
|
if strings.TrimSpace(replyTo) != "" {
|
|
replyID := strings.TrimSpace(replyTo)
|
|
body["msg_id"] = replyID
|
|
body["message_reference"] = map[string]any{"message_id": replyID}
|
|
}
|
|
return c.doJSON(ctx, http.MethodPost, "/channels/"+target.ID+"/messages", body, &qqMessageResponse{})
|
|
default:
|
|
return fmt.Errorf("unsupported qq target kind: %s", target.Kind)
|
|
}
|
|
}
|
|
|
|
func (c *qqClient) sendProactive(ctx context.Context, path, text string, markdown bool) error {
|
|
body := map[string]any{}
|
|
if markdown {
|
|
body["markdown"] = map[string]any{"content": text}
|
|
body["msg_type"] = 2
|
|
} else {
|
|
body["content"] = text
|
|
body["msg_type"] = 0
|
|
}
|
|
return c.doJSON(ctx, http.MethodPost, path, body, &qqMessageResponse{})
|
|
}
|
|
|
|
// buildReplyTextBody assembles the request body for a text reply: the trimmed
// replyTo as msg_id, seq as msg_seq, and the text either as markdown
// (msg_type 2) or as plain content (msg_type 0).
func buildReplyTextBody(text, replyTo string, seq int, markdown bool) map[string]any {
	msgID := strings.TrimSpace(replyTo)
	if markdown {
		return map[string]any{
			"msg_id":   msgID,
			"msg_seq":  seq,
			"markdown": map[string]any{"content": text},
			"msg_type": 2,
		}
	}
	return map[string]any{
		"msg_id":   msgID,
		"msg_seq":  seq,
		"content":  text,
		"msg_type": 0,
	}
}
|
|
|
|
func (c *qqClient) sendInputHint(ctx context.Context, openID, replyTo string) error {
|
|
if strings.TrimSpace(openID) == "" || strings.TrimSpace(replyTo) == "" {
|
|
return nil
|
|
}
|
|
body := map[string]any{
|
|
"msg_type": 6,
|
|
"input_notify": map[string]any{
|
|
"input_type": 1,
|
|
"input_second": 60,
|
|
},
|
|
"msg_seq": c.nextMsgSeq(replyTo),
|
|
"msg_id": strings.TrimSpace(replyTo),
|
|
}
|
|
return c.doJSON(ctx, http.MethodPost, "/v2/users/"+openID+"/messages", body, nil)
|
|
}
|
|
|
|
func (c *qqClient) uploadMedia(ctx context.Context, target qqTarget, fileType int, rawBase64, fileName string) (string, error) {
|
|
rawBase64 = strings.TrimSpace(rawBase64)
|
|
if rawBase64 == "" {
|
|
return "", errors.New("qq upload requires file_data")
|
|
}
|
|
body := map[string]any{
|
|
"file_type": fileType,
|
|
"srv_send_msg": false,
|
|
}
|
|
body["file_data"] = rawBase64
|
|
if fileType == qqMediaTypeFile && strings.TrimSpace(fileName) != "" {
|
|
body["file_name"] = strings.TrimSpace(fileName)
|
|
}
|
|
|
|
var path string
|
|
switch target.Kind {
|
|
case qqTargetC2C:
|
|
path = "/v2/users/" + target.ID + "/files"
|
|
case qqTargetGroup:
|
|
path = "/v2/groups/" + target.ID + "/files"
|
|
default:
|
|
return "", fmt.Errorf("qq upload not supported for target kind: %s", target.Kind)
|
|
}
|
|
|
|
var result qqUploadResponse
|
|
if err := c.doJSON(ctx, http.MethodPost, path, body, &result); err != nil {
|
|
return "", err
|
|
}
|
|
if strings.TrimSpace(result.FileInfo) == "" {
|
|
return "", errors.New("qq upload response missing file_info")
|
|
}
|
|
return result.FileInfo, nil
|
|
}
|
|
|
|
func (c *qqClient) sendMedia(ctx context.Context, target qqTarget, fileInfo, replyTo, content string) error {
|
|
body := map[string]any{
|
|
"msg_type": 7,
|
|
"media": map[string]any{
|
|
"file_info": fileInfo,
|
|
},
|
|
}
|
|
if strings.TrimSpace(content) != "" {
|
|
body["content"] = strings.TrimSpace(content)
|
|
}
|
|
if strings.TrimSpace(replyTo) != "" {
|
|
body["msg_id"] = strings.TrimSpace(replyTo)
|
|
body["msg_seq"] = c.nextMsgSeq(replyTo)
|
|
} else {
|
|
body["msg_seq"] = 1
|
|
}
|
|
|
|
switch target.Kind {
|
|
case qqTargetC2C:
|
|
return c.doJSON(ctx, http.MethodPost, "/v2/users/"+target.ID+"/messages", body, &qqMessageResponse{})
|
|
case qqTargetGroup:
|
|
return c.doJSON(ctx, http.MethodPost, "/v2/groups/"+target.ID+"/messages", body, &qqMessageResponse{})
|
|
default:
|
|
return fmt.Errorf("qq media send not supported for target kind: %s", target.Kind)
|
|
}
|
|
}
|
|
|
|
// isLocalhost reports whether host, in the "host" or "host:port" form found in
// url.URL.Host, names the local machine: "localhost", 127.0.0.1 or ::1. The
// comparison is case-insensitive.
//
// url.URL.Host keeps the square brackets around IPv6 literals, so the loopback
// address appears as "[::1]" without a port and "[::1]:port" with one; both
// are accepted, as is the bare "::1".
func isLocalhost(host string) bool {
	host = strings.ToLower(host)

	switch host {
	case "localhost", "127.0.0.1", "::1", "[::1]":
		return true
	}
	// The trailing colon prevents names such as "localhost.example.com" from
	// matching.
	for _, prefix := range []string{"127.0.0.1:", "[::1]:", "localhost:"} {
		if strings.HasPrefix(host, prefix) {
			return true
		}
	}
	return false
}
|