Files
Memoh/internal/mcp/providers/web/provider.go
T
2026-02-26 15:54:14 +08:00

409 lines
12 KiB
Go

package web
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
mcpgw "github.com/memohai/memoh/internal/mcp"
"github.com/memohai/memoh/internal/searchproviders"
"github.com/memohai/memoh/internal/settings"
)
const (
toolWebSearch = "web_search"
)
type Executor struct {
logger *slog.Logger
settings *settings.Service
searchProviders *searchproviders.Service
}
func NewExecutor(log *slog.Logger, settingsSvc *settings.Service, searchSvc *searchproviders.Service) *Executor {
if log == nil {
log = slog.Default()
}
return &Executor{
logger: log.With(slog.String("provider", "web_tool")),
settings: settingsSvc,
searchProviders: searchSvc,
}
}
func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) {
if p.settings == nil || p.searchProviders == nil {
return []mcpgw.ToolDescriptor{}, nil
}
return []mcpgw.ToolDescriptor{
{
Name: toolWebSearch,
Description: "Search web results via configured search provider.",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{"type": "string", "description": "Search query"},
"count": map[string]any{"type": "integer", "description": "Number of results, default 5"},
},
"required": []string{"query"},
},
},
}, nil
}
func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) {
if p.settings == nil || p.searchProviders == nil {
return mcpgw.BuildToolErrorResult("web tools are not available"), nil
}
botID := strings.TrimSpace(session.BotID)
if botID == "" {
return mcpgw.BuildToolErrorResult("bot_id is required"), nil
}
botSettings, err := p.settings.GetBot(ctx, botID)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
searchProviderID := strings.TrimSpace(botSettings.SearchProviderID)
if searchProviderID == "" {
return mcpgw.BuildToolErrorResult("search provider not configured for this bot"), nil
}
provider, err := p.searchProviders.GetRawByID(ctx, searchProviderID)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
switch toolName {
case toolWebSearch:
return p.callWebSearch(ctx, provider.Provider, provider.Config, arguments)
default:
return nil, mcpgw.ErrToolNotFound
}
}
func (p *Executor) callWebSearch(ctx context.Context, providerName string, configJSON []byte, arguments map[string]any) (map[string]any, error) {
query := strings.TrimSpace(mcpgw.StringArg(arguments, "query"))
if query == "" {
return mcpgw.BuildToolErrorResult("query is required"), nil
}
count := 5
if value, ok, err := mcpgw.IntArg(arguments, "count"); err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
} else if ok && value > 0 {
count = value
}
if count > 20 {
count = 20
}
switch strings.TrimSpace(providerName) {
case string(searchproviders.ProviderBrave):
return p.callBraveSearch(ctx, configJSON, query, count)
case string(searchproviders.ProviderBing):
return p.callBingSearch(ctx, configJSON, query, count)
case string(searchproviders.ProviderGoogle):
return p.callGoogleSearch(ctx, configJSON, query, count)
case string(searchproviders.ProviderTavily):
return p.callTavilySearch(ctx, configJSON, query, count)
default:
return mcpgw.BuildToolErrorResult("unsupported search provider"), nil
}
}
func (p *Executor) callBraveSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) {
cfg := parseConfig(configJSON)
endpoint := strings.TrimRight(firstNonEmpty(stringValue(cfg["base_url"]), "https://api.search.brave.com/res/v1/web/search"), "/")
reqURL, err := url.Parse(endpoint)
if err != nil {
return mcpgw.BuildToolErrorResult("invalid search provider base_url"), nil
}
params := reqURL.Query()
params.Set("q", query)
params.Set("count", fmt.Sprintf("%d", count))
reqURL.RawQuery = params.Encode()
timeout := parseTimeout(configJSON, 15*time.Second)
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
req.Header.Set("Accept", "application/json")
apiKey := stringValue(cfg["api_key"])
if strings.TrimSpace(apiKey) != "" {
req.Header.Set("X-Subscription-Token", strings.TrimSpace(apiKey))
}
resp, err := client.Do(req)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return mcpgw.BuildToolErrorResult("search request failed"), nil
}
var raw struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &raw); err != nil {
return mcpgw.BuildToolErrorResult("invalid search response"), nil
}
results := make([]map[string]any, 0, len(raw.Web.Results))
for _, item := range raw.Web.Results {
results = append(results, map[string]any{
"title": item.Title,
"url": item.URL,
"description": item.Description,
})
}
return mcpgw.BuildToolSuccessResult(map[string]any{
"query": query,
"results": results,
}), nil
}
func (p *Executor) callBingSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) {
cfg := parseConfig(configJSON)
endpoint := strings.TrimRight(firstNonEmpty(stringValue(cfg["base_url"]), "https://api.bing.microsoft.com/v7.0/search"), "/")
reqURL, err := url.Parse(endpoint)
if err != nil {
return mcpgw.BuildToolErrorResult("invalid search provider base_url"), nil
}
params := reqURL.Query()
params.Set("q", query)
params.Set("count", fmt.Sprintf("%d", count))
reqURL.RawQuery = params.Encode()
timeout := parseTimeout(configJSON, 15*time.Second)
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
req.Header.Set("Accept", "application/json")
apiKey := stringValue(cfg["api_key"])
if strings.TrimSpace(apiKey) != "" {
req.Header.Set("Ocp-Apim-Subscription-Key", strings.TrimSpace(apiKey))
}
resp, err := client.Do(req)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return mcpgw.BuildToolErrorResult("search request failed"), nil
}
var raw struct {
WebPages struct {
Value []struct {
Name string `json:"name"`
URL string `json:"url"`
Snippet string `json:"snippet"`
} `json:"value"`
} `json:"webPages"`
}
if err := json.Unmarshal(body, &raw); err != nil {
return mcpgw.BuildToolErrorResult("invalid search response"), nil
}
results := make([]map[string]any, 0, len(raw.WebPages.Value))
for _, item := range raw.WebPages.Value {
results = append(results, map[string]any{
"title": item.Name,
"url": item.URL,
"description": item.Snippet,
})
}
return mcpgw.BuildToolSuccessResult(map[string]any{
"query": query,
"results": results,
}), nil
}
func (p *Executor) callGoogleSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) {
cfg := parseConfig(configJSON)
endpoint := strings.TrimRight(firstNonEmpty(stringValue(cfg["base_url"]), "https://customsearch.googleapis.com/customsearch/v1"), "/")
reqURL, err := url.Parse(endpoint)
if err != nil {
return mcpgw.BuildToolErrorResult("invalid search provider base_url"), nil
}
cx := stringValue(cfg["cx"])
if cx == "" {
return mcpgw.BuildToolErrorResult("Google Custom Search requires cx (Search Engine ID)"), nil
}
if count > 10 {
count = 10
}
params := reqURL.Query()
params.Set("q", query)
params.Set("cx", cx)
params.Set("num", fmt.Sprintf("%d", count))
apiKey := stringValue(cfg["api_key"])
if strings.TrimSpace(apiKey) != "" {
params.Set("key", strings.TrimSpace(apiKey))
}
reqURL.RawQuery = params.Encode()
timeout := parseTimeout(configJSON, 15*time.Second)
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return mcpgw.BuildToolErrorResult("search request failed"), nil
}
var raw struct {
Items []struct {
Title string `json:"title"`
Link string `json:"link"`
Snippet string `json:"snippet"`
} `json:"items"`
}
if err := json.Unmarshal(body, &raw); err != nil {
return mcpgw.BuildToolErrorResult("invalid search response"), nil
}
results := make([]map[string]any, 0, len(raw.Items))
for _, item := range raw.Items {
results = append(results, map[string]any{
"title": item.Title,
"url": item.Link,
"description": item.Snippet,
})
}
return mcpgw.BuildToolSuccessResult(map[string]any{
"query": query,
"results": results,
}), nil
}
func (p *Executor) callTavilySearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) {
cfg := parseConfig(configJSON)
endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://api.tavily.com/search")
apiKey := stringValue(cfg["api_key"])
if apiKey == "" {
return mcpgw.BuildToolErrorResult("Tavily API key is required"), nil
}
payload, _ := json.Marshal(map[string]any{
"query": query,
"max_results": count,
})
timeout := parseTimeout(configJSON, 15*time.Second)
client := &http.Client{Timeout: timeout}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := client.Do(req)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return mcpgw.BuildToolErrorResult(err.Error()), nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return mcpgw.BuildToolErrorResult("search request failed"), nil
}
var raw struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
} `json:"results"`
}
if err := json.Unmarshal(body, &raw); err != nil {
return mcpgw.BuildToolErrorResult("invalid search response"), nil
}
results := make([]map[string]any, 0, len(raw.Results))
for _, item := range raw.Results {
results = append(results, map[string]any{
"title": item.Title,
"url": item.URL,
"description": item.Content,
})
}
return mcpgw.BuildToolSuccessResult(map[string]any{
"query": query,
"results": results,
}), nil
}
func parseTimeout(configJSON []byte, fallback time.Duration) time.Duration {
cfg := parseConfig(configJSON)
raw, ok := cfg["timeout_seconds"]
if !ok {
return fallback
}
switch value := raw.(type) {
case float64:
if value > 0 {
return time.Duration(value * float64(time.Second))
}
case int:
if value > 0 {
return time.Duration(value) * time.Second
}
}
return fallback
}
func parseConfig(configJSON []byte) map[string]any {
if len(configJSON) == 0 {
return map[string]any{}
}
var cfg map[string]any
if err := json.Unmarshal(configJSON, &cfg); err != nil || cfg == nil {
return map[string]any{}
}
return cfg
}
func stringValue(raw any) string {
if value, ok := raw.(string); ok {
return strings.TrimSpace(value)
}
return ""
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return strings.TrimSpace(value)
}
}
return ""
}