mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix: mcp
This commit is contained in:
+2
-1
@@ -97,4 +97,5 @@ memory.db
|
||||
|
||||
config.toml
|
||||
|
||||
.workdocs/
|
||||
.workdocs/
|
||||
data
|
||||
+1
-2
@@ -38,7 +38,7 @@ export const createAgent = ({
|
||||
const fs: HTTPMCPConnection = {
|
||||
type: 'http',
|
||||
name: 'fs',
|
||||
url: `http://localhost:8080/bots/${identity.botId}/container/fs`,
|
||||
url: `${auth.baseUrl}/bots/${identity.botId}/container/fs`,
|
||||
headers: {
|
||||
'Authorization': `Bearer ${auth.bearer}`,
|
||||
},
|
||||
@@ -65,7 +65,6 @@ export const createAgent = ({
|
||||
identity,
|
||||
})
|
||||
const defaultMCPConnections = getDefaultMCPConnections()
|
||||
console.log('defaultMCPConnections', defaultMCPConnections)
|
||||
const { tools: mcpTools, close: closeMCP } = await getMCPTools([
|
||||
...defaultMCPConnections,
|
||||
...mcpConnections,
|
||||
|
||||
+48
-30
@@ -1,43 +1,61 @@
|
||||
import { Elysia } from 'elysia'
|
||||
import { chatModule } from './modules/chat'
|
||||
import { corsMiddleware } from './middlewares/cors'
|
||||
import { errorMiddleware } from './middlewares/error'
|
||||
import { loadConfig } from './config'
|
||||
import { join } from 'path'
|
||||
import { Elysia } from "elysia";
|
||||
import { chatModule } from "./modules/chat";
|
||||
import { corsMiddleware } from "./middlewares/cors";
|
||||
import { errorMiddleware } from "./middlewares/error";
|
||||
import { loadConfig } from "./config";
|
||||
import { join } from "path";
|
||||
|
||||
const config = loadConfig('../config.toml')
|
||||
const config = loadConfig("../config.toml");
|
||||
|
||||
export type AuthFetcher = (url: string, options?: RequestInit) => Promise<Response>
|
||||
export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => {
|
||||
return async (url: string, options?: RequestInit) => {
|
||||
const requestOptions = options ?? {}
|
||||
const headers = new Headers(requestOptions.headers || {})
|
||||
if (bearer) {
|
||||
headers.set('Authorization', `Bearer ${bearer}`)
|
||||
}
|
||||
let baseUrl = ''
|
||||
if (!baseUrl) {
|
||||
baseUrl = 'http://127.0.0.1'
|
||||
}
|
||||
if (typeof config.server.addr === 'string' && config.server.addr.startsWith(':')) {
|
||||
baseUrl = `http://127.0.0.1${config.server.addr}`
|
||||
}
|
||||
return await fetch(join(baseUrl, url), {
|
||||
...requestOptions,
|
||||
headers,
|
||||
})
|
||||
export const getBraveConfig = () => {
|
||||
return {
|
||||
apiKey: config.brave.api_key ?? "",
|
||||
baseUrl: config.brave.base_url ?? "https://api.search.brave.com/res/v1/",
|
||||
}
|
||||
}
|
||||
|
||||
export const getBaseUrl = () => {
|
||||
let baseUrl = "";
|
||||
if (!baseUrl) {
|
||||
baseUrl = "http://127.0.0.1";
|
||||
}
|
||||
if (
|
||||
typeof config.server.addr === "string" &&
|
||||
config.server.addr.startsWith(":")
|
||||
) {
|
||||
baseUrl = `http://127.0.0.1${config.server.addr}`;
|
||||
}
|
||||
return baseUrl;
|
||||
};
|
||||
|
||||
export type AuthFetcher = (
|
||||
url: string,
|
||||
options?: RequestInit,
|
||||
) => Promise<Response>;
|
||||
export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => {
|
||||
return async (url: string, options?: RequestInit) => {
|
||||
const requestOptions = options ?? {};
|
||||
const headers = new Headers(requestOptions.headers || {});
|
||||
if (bearer) {
|
||||
headers.set("Authorization", `Bearer ${bearer}`);
|
||||
}
|
||||
|
||||
return await fetch(join(getBaseUrl(), url), {
|
||||
...requestOptions,
|
||||
headers,
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
const app = new Elysia()
|
||||
.use(corsMiddleware)
|
||||
.use(errorMiddleware)
|
||||
.use(chatModule)
|
||||
.listen({
|
||||
port: config.agent_gateway.port ?? 8081,
|
||||
hostname: config.agent_gateway.host ?? '127.0.0.1',
|
||||
})
|
||||
hostname: config.agent_gateway.host ?? "127.0.0.1",
|
||||
});
|
||||
|
||||
console.log(
|
||||
`Agent Gateway is running at ${app.server?.hostname}:${app.server?.port}`
|
||||
)
|
||||
`Agent Gateway is running at ${app.server?.hostname}:${app.server?.port}`,
|
||||
);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Elysia, sse } from 'elysia'
|
||||
import z from 'zod'
|
||||
import { createAgent } from '../agent'
|
||||
import { createAuthFetcher } from '../index'
|
||||
import { createAuthFetcher, getBaseUrl, getBraveConfig } from '../index'
|
||||
import { ModelConfig } from '../types'
|
||||
import { bearerMiddleware } from '../middlewares/bearer'
|
||||
import { AllowedActionModel, AttachmentModel, IdentityContextModel, MCPConnectionModel, ModelConfigModel, ScheduleModel } from '../models'
|
||||
@@ -35,7 +35,9 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
mcpConnections: body.mcpConnections,
|
||||
auth: {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
},
|
||||
brave: getBraveConfig(),
|
||||
}, authFetcher)
|
||||
return ask({
|
||||
query: body.query,
|
||||
@@ -62,7 +64,9 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
mcpConnections: body.mcpConnections,
|
||||
auth: {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
},
|
||||
brave: getBraveConfig(),
|
||||
}, authFetcher)
|
||||
for await (const action of stream({
|
||||
query: body.query,
|
||||
@@ -95,7 +99,9 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
mcpConnections: body.mcpConnections,
|
||||
auth: {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
},
|
||||
brave: getBraveConfig(),
|
||||
}, authFetcher)
|
||||
return triggerSchedule({
|
||||
schedule: body.schedule,
|
||||
|
||||
+15
-13
@@ -12,7 +12,7 @@ export const getMCPTools = async (connections: MCPConnection[]) => {
|
||||
headers: connection.headers,
|
||||
}
|
||||
})
|
||||
closeCallbacks.push(client.close)
|
||||
closeCallbacks.push(() => client.close())
|
||||
return await client.tools()
|
||||
}
|
||||
|
||||
@@ -24,26 +24,28 @@ export const getMCPTools = async (connections: MCPConnection[]) => {
|
||||
headers: connection.headers,
|
||||
}
|
||||
})
|
||||
closeCallbacks.push(client.close)
|
||||
closeCallbacks.push(() => client.close())
|
||||
return await client.tools()
|
||||
}
|
||||
|
||||
const getStdioTools = async (connection: StdioMCPConnection) => {
|
||||
// TODO: Implement stdio tools
|
||||
return []
|
||||
return {}
|
||||
}
|
||||
|
||||
const toolSets = await Promise.all(connections.map(connection => {
|
||||
switch (connection.type) {
|
||||
case 'http':
|
||||
return getHTTPTools(connection)
|
||||
case 'sse':
|
||||
return getSSETools(connection)
|
||||
case 'stdio':
|
||||
return getStdioTools(connection)
|
||||
}
|
||||
}))
|
||||
|
||||
return {
|
||||
tools: await Promise.all(connections.map(connection => {
|
||||
switch (connection.type) {
|
||||
case 'http':
|
||||
return getHTTPTools(connection)
|
||||
case 'sse':
|
||||
return getSSETools(connection)
|
||||
case 'stdio':
|
||||
return getStdioTools(connection)
|
||||
}
|
||||
})),
|
||||
tools: Object.assign({}, ...toolSets),
|
||||
close: async () => {
|
||||
await Promise.all(closeCallbacks.map(callback => callback()))
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ export interface IdentityContext {
|
||||
|
||||
export interface AgentAuthContext {
|
||||
bearer: string
|
||||
baseUrl: string
|
||||
}
|
||||
|
||||
export enum AgentAction {
|
||||
|
||||
@@ -93,6 +93,14 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error {
|
||||
Error: &mcptools.JSONRPCError{Code: -32601, Message: "method not found"},
|
||||
})
|
||||
}
|
||||
if len(req.ID) == 0 && strings.HasPrefix(req.Method, "notifications/") {
|
||||
if err := h.notifyMCPServer(c.Request().Context(), containerID, req); err != nil {
|
||||
h.logger.Error("mcp fs notify failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
// MCP Streamable HTTP spec: notifications must be answered with 202 Accepted and no body.
|
||||
return c.NoContent(http.StatusAccepted)
|
||||
}
|
||||
payload, err := h.callMCPServer(c.Request().Context(), containerID, req)
|
||||
if err != nil {
|
||||
h.logger.Error("mcp fs call failed", slog.Any("error", err), slog.String("method", req.Method), slog.String("bot_id", botID), slog.String("container_id", containerID))
|
||||
@@ -140,6 +148,14 @@ func (h *ContainerdHandler) callMCPServer(ctx context.Context, containerID strin
|
||||
return session.call(ctx, req)
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) notifyMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) error {
|
||||
session, err := h.getMCPSession(ctx, containerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return session.notify(ctx, req)
|
||||
}
|
||||
|
||||
type mcpSession struct {
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
@@ -420,6 +436,22 @@ func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error {
|
||||
payloads, err := buildMCPNotificationPayloads(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.writeMu.Lock()
|
||||
for _, payload := range payloads {
|
||||
if _, err := s.stdin.Write([]byte(payload + "\n")); err != nil {
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.writeMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildMCPPayloads(req mcptools.JSONRPCRequest, initOnce *sync.Once) ([]string, json.RawMessage, error) {
|
||||
if req.JSONRPC == "" {
|
||||
req.JSONRPC = "2.0"
|
||||
@@ -480,3 +512,17 @@ func buildMCPPayloads(req mcptools.JSONRPCRequest, initOnce *sync.Once) ([]strin
|
||||
payloads = append(payloads, string(reqBytes))
|
||||
return payloads, targetID, nil
|
||||
}
|
||||
|
||||
func buildMCPNotificationPayloads(req mcptools.JSONRPCRequest) ([]string, error) {
|
||||
if req.JSONRPC == "" {
|
||||
req.JSONRPC = "2.0"
|
||||
}
|
||||
if strings.TrimSpace(req.Method) == "" {
|
||||
return nil, fmt.Errorf("missing method")
|
||||
}
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{string(reqBytes)}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user