This commit is contained in:
Acbox
2026-02-08 01:02:04 +08:00
parent 4e661bae76
commit da671a658c
7 changed files with 120 additions and 47 deletions
+2 -1
View File
@@ -97,4 +97,5 @@ memory.db
config.toml
.workdocs/
.workdocs/
data
+1 -2
View File
@@ -38,7 +38,7 @@ export const createAgent = ({
const fs: HTTPMCPConnection = {
type: 'http',
name: 'fs',
url: `http://localhost:8080/bots/${identity.botId}/container/fs`,
url: `${auth.baseUrl}/bots/${identity.botId}/container/fs`,
headers: {
'Authorization': `Bearer ${auth.bearer}`,
},
@@ -65,7 +65,6 @@ export const createAgent = ({
identity,
})
const defaultMCPConnections = getDefaultMCPConnections()
console.log('defaultMCPConnections', defaultMCPConnections)
const { tools: mcpTools, close: closeMCP } = await getMCPTools([
...defaultMCPConnections,
...mcpConnections,
+48 -30
View File
@@ -1,43 +1,61 @@
import { Elysia } from 'elysia'
import { chatModule } from './modules/chat'
import { corsMiddleware } from './middlewares/cors'
import { errorMiddleware } from './middlewares/error'
import { loadConfig } from './config'
import { join } from 'path'
import { Elysia } from "elysia";
import { chatModule } from "./modules/chat";
import { corsMiddleware } from "./middlewares/cors";
import { errorMiddleware } from "./middlewares/error";
import { loadConfig } from "./config";
import { join } from "path";
const config = loadConfig('../config.toml')
const config = loadConfig("../config.toml");
export type AuthFetcher = (url: string, options?: RequestInit) => Promise<Response>
export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => {
return async (url: string, options?: RequestInit) => {
const requestOptions = options ?? {}
const headers = new Headers(requestOptions.headers || {})
if (bearer) {
headers.set('Authorization', `Bearer ${bearer}`)
}
let baseUrl = ''
if (!baseUrl) {
baseUrl = 'http://127.0.0.1'
}
if (typeof config.server.addr === 'string' && config.server.addr.startsWith(':')) {
baseUrl = `http://127.0.0.1${config.server.addr}`
}
return await fetch(join(baseUrl, url), {
...requestOptions,
headers,
})
export const getBraveConfig = () => {
return {
apiKey: config.brave.api_key ?? "",
baseUrl: config.brave.base_url ?? "https://api.search.brave.com/res/v1/",
}
}
export const getBaseUrl = () => {
let baseUrl = "";
if (!baseUrl) {
baseUrl = "http://127.0.0.1";
}
if (
typeof config.server.addr === "string" &&
config.server.addr.startsWith(":")
) {
baseUrl = `http://127.0.0.1${config.server.addr}`;
}
return baseUrl;
};
export type AuthFetcher = (
url: string,
options?: RequestInit,
) => Promise<Response>;
export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => {
return async (url: string, options?: RequestInit) => {
const requestOptions = options ?? {};
const headers = new Headers(requestOptions.headers || {});
if (bearer) {
headers.set("Authorization", `Bearer ${bearer}`);
}
return await fetch(join(getBaseUrl(), url), {
...requestOptions,
headers,
});
};
};
const app = new Elysia()
.use(corsMiddleware)
.use(errorMiddleware)
.use(chatModule)
.listen({
port: config.agent_gateway.port ?? 8081,
hostname: config.agent_gateway.host ?? '127.0.0.1',
})
hostname: config.agent_gateway.host ?? "127.0.0.1",
});
console.log(
`Agent Gateway is running at ${app.server?.hostname}:${app.server?.port}`
)
`Agent Gateway is running at ${app.server?.hostname}:${app.server?.port}`,
);
+7 -1
View File
@@ -1,7 +1,7 @@
import { Elysia, sse } from 'elysia'
import z from 'zod'
import { createAgent } from '../agent'
import { createAuthFetcher } from '../index'
import { createAuthFetcher, getBaseUrl, getBraveConfig } from '../index'
import { ModelConfig } from '../types'
import { bearerMiddleware } from '../middlewares/bearer'
import { AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models'
@@ -35,7 +35,9 @@ export const chatModule = new Elysia({ prefix: '/chat' })
mcpConnections: body.mcpConnections,
auth: {
bearer: bearer!,
baseUrl: getBaseUrl(),
},
brave: getBraveConfig(),
}, authFetcher)
return ask({
query: body.query,
@@ -62,7 +64,9 @@ export const chatModule = new Elysia({ prefix: '/chat' })
mcpConnections: body.mcpConnections,
auth: {
bearer: bearer!,
baseUrl: getBaseUrl(),
},
brave: getBraveConfig(),
}, authFetcher)
for await (const action of stream({
query: body.query,
@@ -95,7 +99,9 @@ export const chatModule = new Elysia({ prefix: '/chat' })
mcpConnections: body.mcpConnections,
auth: {
bearer: bearer!,
baseUrl: getBaseUrl(),
},
brave: getBraveConfig(),
}, authFetcher)
return triggerSchedule({
schedule: body.schedule,
+15 -13
View File
@@ -12,7 +12,7 @@ export const getMCPTools = async (connections: MCPConnection[]) => {
headers: connection.headers,
}
})
closeCallbacks.push(client.close)
closeCallbacks.push(() => client.close())
return await client.tools()
}
@@ -24,26 +24,28 @@ export const getMCPTools = async (connections: MCPConnection[]) => {
headers: connection.headers,
}
})
closeCallbacks.push(client.close)
closeCallbacks.push(() => client.close())
return await client.tools()
}
const getStdioTools = async (connection: StdioMCPConnection) => {
// TODO: Implement stdio tools
return []
return {}
}
const toolSets = await Promise.all(connections.map(connection => {
switch (connection.type) {
case 'http':
return getHTTPTools(connection)
case 'sse':
return getSSETools(connection)
case 'stdio':
return getStdioTools(connection)
}
}))
return {
tools: await Promise.all(connections.map(connection => {
switch (connection.type) {
case 'http':
return getHTTPTools(connection)
case 'sse':
return getSSETools(connection)
case 'stdio':
return getStdioTools(connection)
}
})),
tools: Object.assign({}, ...toolSets),
close: async () => {
await Promise.all(closeCallbacks.map(callback => callback()))
}
+1
View File
@@ -20,6 +20,7 @@ export interface IdentityContext {
export interface AgentAuthContext {
bearer: string
baseUrl: string
}
export enum AgentAction {
+46
View File
@@ -93,6 +93,14 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error {
Error: &mcptools.JSONRPCError{Code: -32601, Message: "method not found"},
})
}
if len(req.ID) == 0 && strings.HasPrefix(req.Method, "notifications/") {
if err := h.notifyMCPServer(c.Request().Context(), containerID, req); err != nil {
h.logger.Error("mcp fs notify failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID))
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
// MCP Streamable HTTP spec: notifications must be answered with 202 Accepted and no body.
return c.NoContent(http.StatusAccepted)
}
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
if err != nil {
h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID))
@@ -140,6 +148,14 @@ func (h *ContainerdHandler) callMCPServer(ctx context.Context, containerID strin
return session.call(ctx, req)
}
func (h *ContainerdHandler) notifyMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) error {
session, err := h.getMCPSession(ctx, containerID)
if err != nil {
return err
}
return session.notify(ctx, req)
}
type mcpSession struct {
stdin io.WriteCloser
stdout io.ReadCloser
@@ -420,6 +436,22 @@ func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map
}
}
func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error {
payloads, err := buildMCPNotificationPayloads(req)
if err != nil {
return err
}
s.writeMu.Lock()
for _, payload := range payloads {
if _, err := s.stdin.Write([]byte(payload + "\n")); err != nil {
s.writeMu.Unlock()
return err
}
}
s.writeMu.Unlock()
return nil
}
func buildMCPPayloads(req mcptools.JSONRPCRequest, initOnce *sync.Once) ([]string, json.RawMessage, error) {
if req.JSONRPC == "" {
req.JSONRPC = "2.0"
@@ -480,3 +512,17 @@ func buildMCPPayloads(req mcptools.JSONRPCRequest, initOnce *sync.Once) ([]strin
payloads = append(payloads, string(reqBytes))
return payloads, targetID, nil
}
func buildMCPNotificationPayloads(req mcptools.JSONRPCRequest) ([]string, error) {
if req.JSONRPC == "" {
req.JSONRPC = "2.0"
}
if strings.TrimSpace(req.Method) == "" {
return nil, fmt.Errorf("missing method")
}
reqBytes, err := json.Marshal(req)
if err != nil {
return nil, err
}
return []string{string(reqBytes)}, nil
}