mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refact: go mcp tool in containerd
This commit is contained in:
+27
-159
@@ -2,15 +2,11 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/chat"
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
@@ -25,58 +21,6 @@ import (
|
||||
"github.com/memohai/memoh/internal/server"
|
||||
)
|
||||
|
||||
type resolverTextEmbedder struct {
|
||||
resolver *embeddings.Resolver
|
||||
modelID string
|
||||
dims int
|
||||
}
|
||||
|
||||
func (e *resolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) {
|
||||
result, err := e.resolver.Embed(ctx, embeddings.Request{
|
||||
Type: embeddings.TypeText,
|
||||
Model: e.modelID,
|
||||
Input: embeddings.Input{Text: input},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Embedding, nil
|
||||
}
|
||||
|
||||
func (e *resolverTextEmbedder) Dimensions() int {
|
||||
return e.dims
|
||||
}
|
||||
|
||||
func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) {
|
||||
candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return nil, models.GetResponse{}, models.GetResponse{}, false, err
|
||||
}
|
||||
vectors := map[string]int{}
|
||||
var textModel models.GetResponse
|
||||
var multimodalModel models.GetResponse
|
||||
for _, model := range candidates {
|
||||
if model.Dimensions > 0 && model.ModelID != "" {
|
||||
vectors[model.ModelID] = model.Dimensions
|
||||
}
|
||||
if model.IsMultimodal {
|
||||
if multimodalModel.ModelID == "" {
|
||||
multimodalModel = model
|
||||
}
|
||||
continue
|
||||
}
|
||||
if textModel.ModelID == "" {
|
||||
textModel = model
|
||||
}
|
||||
}
|
||||
|
||||
hasTextModel := textModel.ModelID != ""
|
||||
hasMultimodalModel := multimodalModel.ModelID != ""
|
||||
hasAnyModel := hasTextModel || hasMultimodalModel
|
||||
|
||||
return vectors, textModel, multimodalModel, hasAnyModel, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
cfgPath := os.Getenv("CONFIG_PATH")
|
||||
@@ -98,7 +42,11 @@ func main() {
|
||||
addr = value
|
||||
}
|
||||
|
||||
factory := ctr.DefaultClientFactory{SocketPath: cfg.Containerd.SocketPath}
|
||||
socketPath := cfg.Containerd.SocketPath
|
||||
if value := os.Getenv("CONTAINERD_SOCKET"); value != "" {
|
||||
socketPath = value
|
||||
}
|
||||
factory := ctr.DefaultClientFactory{SocketPath: socketPath}
|
||||
client, err := factory.New(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("connect containerd: %v", err)
|
||||
@@ -108,6 +56,9 @@ func main() {
|
||||
service := ctr.NewDefaultService(client, cfg.Containerd.Namespace)
|
||||
manager := mcp.NewManager(service, cfg.MCP)
|
||||
|
||||
pingHandler := handlers.NewPingHandler()
|
||||
containerdHandler := handlers.NewContainerdHandler(service, cfg.MCP, cfg.Containerd.Namespace)
|
||||
|
||||
conn, err := db.Open(ctx, cfg.Postgres)
|
||||
if err != nil {
|
||||
log.Fatalf("db connect: %v", err)
|
||||
@@ -117,35 +68,34 @@ func main() {
|
||||
queries := dbsqlc.New(conn)
|
||||
modelsService := models.NewService(queries)
|
||||
|
||||
pingHandler := handlers.NewPingHandler()
|
||||
authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn)
|
||||
|
||||
|
||||
// Initialize chat resolver for both chat and memory operations
|
||||
chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second)
|
||||
|
||||
|
||||
// Create LLM client for memory operations using chat provider
|
||||
var llmClient memory.LLM
|
||||
memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries)
|
||||
memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, modelsService, queries)
|
||||
if err != nil {
|
||||
log.Fatalf("select memory model: %v\nPlease configure at least one chat model in the database.", err)
|
||||
}
|
||||
|
||||
|
||||
log.Printf("Using memory model: %s (provider: %s)", memoryModel.ModelID, memoryProvider.ClientType)
|
||||
provider, err := createChatProvider(memoryProvider, 30*time.Second)
|
||||
provider, err := chat.CreateProvider(memoryProvider, 30*time.Second)
|
||||
if err != nil {
|
||||
log.Fatalf("create memory provider: %v", err)
|
||||
}
|
||||
llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID)
|
||||
|
||||
|
||||
resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second)
|
||||
vectors, textModel, multimodalModel, hasModels, err := collectEmbeddingVectors(ctx, modelsService)
|
||||
vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService)
|
||||
if err != nil {
|
||||
log.Fatalf("embedding models: %v", err)
|
||||
}
|
||||
|
||||
|
||||
var memoryService *memory.Service
|
||||
var memoryHandler *handlers.MemoryHandler
|
||||
|
||||
|
||||
if !hasModels {
|
||||
log.Println("WARNING: No embedding models configured. Memory service will not be available.")
|
||||
log.Println("You can add embedding models via the /models API endpoint.")
|
||||
@@ -157,17 +107,17 @@ func main() {
|
||||
if multimodalModel.ModelID == "" {
|
||||
log.Println("WARNING: No multimodal embedding model configured. Multimodal embedding features will be limited.")
|
||||
}
|
||||
|
||||
|
||||
var textEmbedder embeddings.Embedder
|
||||
var store *memory.QdrantStore
|
||||
|
||||
|
||||
if textModel.ModelID != "" && textModel.Dimensions > 0 {
|
||||
textEmbedder = &resolverTextEmbedder{
|
||||
resolver: resolver,
|
||||
modelID: textModel.ModelID,
|
||||
dims: textModel.Dimensions,
|
||||
textEmbedder = &embeddings.ResolverTextEmbedder{
|
||||
Resolver: resolver,
|
||||
ModelID: textModel.ModelID,
|
||||
Dims: textModel.Dimensions,
|
||||
}
|
||||
|
||||
|
||||
if len(vectors) > 0 {
|
||||
store, err = memory.NewQdrantStoreWithVectors(
|
||||
cfg.Qdrant.BaseURL,
|
||||
@@ -192,103 +142,21 @@ func main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
memoryService = memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID)
|
||||
memoryHandler = handlers.NewMemoryHandler(memoryService)
|
||||
}
|
||||
embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries)
|
||||
fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace)
|
||||
swaggerHandler := handlers.NewSwaggerHandler()
|
||||
chatHandler := handlers.NewChatHandler(chatResolver)
|
||||
|
||||
|
||||
// Initialize providers and models handlers
|
||||
providersService := providers.NewService(queries)
|
||||
providersHandler := handlers.NewProvidersHandler(providersService)
|
||||
modelsHandler := handlers.NewModelsHandler(modelsService)
|
||||
|
||||
srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, fsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler)
|
||||
srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler, containerdHandler)
|
||||
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatalf("server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// selectMemoryModel selects a chat model for memory operations
|
||||
func selectMemoryModel(ctx context.Context, modelsService *models.Service, queries *dbsqlc.Queries) (models.GetResponse, dbsqlc.LlmProvider, error) {
|
||||
// First try to get the memory-enabled model
|
||||
memoryModel, err := modelsService.GetByEnableAs(ctx, models.EnableAsMemory)
|
||||
if err == nil {
|
||||
provider, err := fetchProviderByID(ctx, queries, memoryModel.LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, dbsqlc.LlmProvider{}, err
|
||||
}
|
||||
return memoryModel, provider, nil
|
||||
}
|
||||
|
||||
// Fallback to chat model
|
||||
chatModel, err := modelsService.GetByEnableAs(ctx, models.EnableAsChat)
|
||||
if err == nil {
|
||||
provider, err := fetchProviderByID(ctx, queries, chatModel.LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, dbsqlc.LlmProvider{}, err
|
||||
}
|
||||
return chatModel, provider, nil
|
||||
}
|
||||
|
||||
// If no enabled models, try to find any chat model
|
||||
candidates, err := modelsService.ListByType(ctx, models.ModelTypeChat)
|
||||
if err != nil || len(candidates) == 0 {
|
||||
return models.GetResponse{}, dbsqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
|
||||
}
|
||||
|
||||
selected := candidates[0]
|
||||
provider, err := fetchProviderByID(ctx, queries, selected.LlmProviderID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, dbsqlc.LlmProvider{}, err
|
||||
}
|
||||
return selected, provider, nil
|
||||
}
|
||||
|
||||
// fetchProviderByID fetches a provider by ID
|
||||
func fetchProviderByID(ctx context.Context, queries *dbsqlc.Queries, providerID string) (dbsqlc.LlmProvider, error) {
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
return dbsqlc.LlmProvider{}, fmt.Errorf("provider id missing")
|
||||
}
|
||||
parsed, err := uuid.Parse(providerID)
|
||||
if err != nil {
|
||||
return dbsqlc.LlmProvider{}, err
|
||||
}
|
||||
pgID := pgtype.UUID{Valid: true}
|
||||
copy(pgID.Bytes[:], parsed[:])
|
||||
return queries.GetLlmProviderByID(ctx, pgID)
|
||||
}
|
||||
|
||||
// createChatProvider creates a chat provider instance
|
||||
func createChatProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (chat.Provider, error) {
|
||||
clientType := strings.ToLower(strings.TrimSpace(provider.ClientType))
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
switch clientType {
|
||||
case chat.ProviderOpenAI, chat.ProviderOpenAICompat:
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return nil, fmt.Errorf("openai api key is required")
|
||||
}
|
||||
return chat.NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout)
|
||||
case chat.ProviderAnthropic:
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return nil, fmt.Errorf("anthropic api key is required")
|
||||
}
|
||||
return chat.NewAnthropicProvider(provider.ApiKey, timeout)
|
||||
case chat.ProviderGoogle:
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return nil, fmt.Errorf("google api key is required")
|
||||
}
|
||||
return chat.NewGoogleProvider(provider.ApiKey, timeout)
|
||||
case chat.ProviderOllama:
|
||||
return chat.NewOllamaProvider(provider.BaseUrl, timeout)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider type: %s", clientType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
FROM golang:1.25-alpine AS build
|
||||
|
||||
WORKDIR /src
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
ARG TARGETARCH
|
||||
ARG COMMIT_HASH=unknown
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH:-amd64} \
|
||||
go build -trimpath -ldflags "-s -w -X main.commitHash=${COMMIT_HASH}" -o /out/mcp ./cmd/mcp
|
||||
|
||||
FROM busybox:latest
|
||||
COPY --from=build /out/mcp /mcp
|
||||
ENTRYPOINT ["/mcp"]
|
||||
+15
-153
@@ -2,165 +2,27 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
gomcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
var (
|
||||
commitHash = "unknown"
|
||||
version = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
usage()
|
||||
return
|
||||
if version == "unknown" {
|
||||
version = "v0.0.0-dev+" + commitHash
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
cfgPath := os.Getenv("CONFIG_PATH")
|
||||
cfg, err := config.Load(cfgPath)
|
||||
if err != nil {
|
||||
log.Fatalf("load config: %v", err)
|
||||
}
|
||||
|
||||
factory := ctr.DefaultClientFactory{SocketPath: cfg.Containerd.SocketPath}
|
||||
client, err := factory.New(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("connect containerd: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
service := ctr.NewDefaultService(client, cfg.Containerd.Namespace)
|
||||
manager := mcp.NewManager(service, cfg.MCP)
|
||||
|
||||
switch os.Args[1] {
|
||||
case "init":
|
||||
if err := manager.Init(ctx); err != nil {
|
||||
log.Fatalf("init: %v", err)
|
||||
}
|
||||
case "list":
|
||||
users, err := manager.ListUsers(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("list: %v", err)
|
||||
}
|
||||
for _, user := range users {
|
||||
fmt.Println(user)
|
||||
}
|
||||
case "create":
|
||||
userID := argAt(2)
|
||||
if err := manager.EnsureUser(ctx, userID); err != nil {
|
||||
log.Fatalf("create: %v", err)
|
||||
}
|
||||
case "start":
|
||||
userID := argAt(2)
|
||||
if err := manager.Start(ctx, userID); err != nil {
|
||||
log.Fatalf("start: %v", err)
|
||||
}
|
||||
case "stop":
|
||||
stopCmd(ctx, manager, os.Args[2:])
|
||||
case "delete":
|
||||
userID := argAt(2)
|
||||
if err := manager.Delete(ctx, userID); err != nil {
|
||||
log.Fatalf("delete: %v", err)
|
||||
}
|
||||
case "exec":
|
||||
withDB(ctx, cfg.Postgres, manager, func() {
|
||||
execCmd(ctx, manager, os.Args[2:])
|
||||
})
|
||||
default:
|
||||
usage()
|
||||
server := gomcp.NewServer(
|
||||
&gomcp.Implementation{Name: "memoh-mcp", Version: version},
|
||||
nil,
|
||||
)
|
||||
mcp.RegisterTools(server)
|
||||
if err := server.Run(context.Background(), &gomcp.StdioTransport{}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func stopCmd(ctx context.Context, manager *mcp.Manager, args []string) {
|
||||
fs := flag.NewFlagSet("stop", flag.ExitOnError)
|
||||
timeout := fs.Duration("timeout", 10*time.Second, "stop timeout")
|
||||
fs.Parse(args)
|
||||
|
||||
userID := fs.Arg(0)
|
||||
if userID == "" {
|
||||
log.Fatalf("stop: user id required")
|
||||
}
|
||||
|
||||
if err := manager.Stop(ctx, userID, *timeout); err != nil {
|
||||
log.Fatalf("stop: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func execCmd(ctx context.Context, manager *mcp.Manager, args []string) {
|
||||
fs := flag.NewFlagSet("exec", flag.ExitOnError)
|
||||
var envs stringSlice
|
||||
cwd := fs.String("cwd", "", "working directory")
|
||||
tty := fs.Bool("tty", false, "allocate a tty")
|
||||
fs.Var(&envs, "env", "environment variable, can be repeated")
|
||||
fs.Parse(args)
|
||||
|
||||
userID := fs.Arg(0)
|
||||
cmdArgs := fs.Args()[1:]
|
||||
if userID == "" || len(cmdArgs) == 0 {
|
||||
log.Fatalf("exec: user id and command required")
|
||||
}
|
||||
|
||||
result, err := manager.Exec(ctx, mcp.ExecRequest{
|
||||
UserID: userID,
|
||||
Command: cmdArgs,
|
||||
Env: envs,
|
||||
WorkDir: *cwd,
|
||||
Terminal: *tty,
|
||||
UseStdio: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("exec: %v", err)
|
||||
}
|
||||
if result.ExitCode != 0 {
|
||||
os.Exit(int(result.ExitCode))
|
||||
}
|
||||
}
|
||||
|
||||
func argAt(index int) string {
|
||||
if len(os.Args) <= index {
|
||||
log.Fatalf("missing argument")
|
||||
}
|
||||
return os.Args[index]
|
||||
}
|
||||
|
||||
type stringSlice []string
|
||||
|
||||
func (s *stringSlice) String() string {
|
||||
return fmt.Sprintf("%v", []string(*s))
|
||||
}
|
||||
|
||||
func (s *stringSlice) Set(value string) error {
|
||||
*s = append(*s, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Println("Usage: mcp <command> [args]")
|
||||
fmt.Println()
|
||||
fmt.Println("Commands:")
|
||||
fmt.Println(" init")
|
||||
fmt.Println(" list")
|
||||
fmt.Println(" create <userID>")
|
||||
fmt.Println(" start <userID>")
|
||||
fmt.Println(" stop <userID> [--timeout=10s]")
|
||||
fmt.Println(" delete <userID>")
|
||||
fmt.Println(" exec <userID> [--cwd=DIR] [--tty] [--env=K=V] -- <cmd> [args...]")
|
||||
fmt.Println(" version-create <userID>")
|
||||
fmt.Println(" version-list <userID>")
|
||||
fmt.Println(" version-rollback <userID> <version>")
|
||||
}
|
||||
|
||||
func withDB(ctx context.Context, cfg config.PostgresConfig, manager *mcp.Manager, fn func()) {
|
||||
conn, err := db.Open(ctx, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("db connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
manager.WithDB(conn)
|
||||
fn()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user