diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 86037096..55b4d9f4 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -2,15 +2,11 @@ package main import ( "context" - "fmt" "log" "os" "strings" "time" - "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgtype" - "github.com/memohai/memoh/internal/chat" "github.com/memohai/memoh/internal/config" ctr "github.com/memohai/memoh/internal/containerd" @@ -25,58 +21,6 @@ import ( "github.com/memohai/memoh/internal/server" ) -type resolverTextEmbedder struct { - resolver *embeddings.Resolver - modelID string - dims int -} - -func (e *resolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) { - result, err := e.resolver.Embed(ctx, embeddings.Request{ - Type: embeddings.TypeText, - Model: e.modelID, - Input: embeddings.Input{Text: input}, - }) - if err != nil { - return nil, err - } - return result.Embedding, nil -} - -func (e *resolverTextEmbedder) Dimensions() int { - return e.dims -} - -func collectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) { - candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding) - if err != nil { - return nil, models.GetResponse{}, models.GetResponse{}, false, err - } - vectors := map[string]int{} - var textModel models.GetResponse - var multimodalModel models.GetResponse - for _, model := range candidates { - if model.Dimensions > 0 && model.ModelID != "" { - vectors[model.ModelID] = model.Dimensions - } - if model.IsMultimodal { - if multimodalModel.ModelID == "" { - multimodalModel = model - } - continue - } - if textModel.ModelID == "" { - textModel = model - } - } - - hasTextModel := textModel.ModelID != "" - hasMultimodalModel := multimodalModel.ModelID != "" - hasAnyModel := hasTextModel || hasMultimodalModel - - return vectors, textModel, multimodalModel, hasAnyModel, nil -} - func main() { ctx := context.Background() cfgPath := os.Getenv("CONFIG_PATH") @@ -98,7 +42,11 @@ func main() { addr = value } - factory := ctr.DefaultClientFactory{SocketPath: cfg.Containerd.SocketPath} + socketPath := cfg.Containerd.SocketPath + if value := os.Getenv("CONTAINERD_SOCKET"); value != "" { + socketPath = value + } + factory := ctr.DefaultClientFactory{SocketPath: socketPath} client, err := factory.New(ctx) if err != nil { log.Fatalf("connect containerd: %v", err) @@ -108,6 +56,9 @@ func main() { service := ctr.NewDefaultService(client, cfg.Containerd.Namespace) manager := mcp.NewManager(service, cfg.MCP) + pingHandler := handlers.NewPingHandler() + containerdHandler := handlers.NewContainerdHandler(service, cfg.MCP, cfg.Containerd.Namespace) + conn, err := db.Open(ctx, cfg.Postgres) if err != nil { log.Fatalf("db connect: %v", err) @@ -117,35 +68,34 @@ func main() { queries := dbsqlc.New(conn) modelsService := models.NewService(queries) - pingHandler := handlers.NewPingHandler() authHandler := handlers.NewAuthHandler(conn, cfg.Auth.JWTSecret, jwtExpiresIn) - + // Initialize chat resolver for both chat and memory operations chatResolver := chat.NewResolver(modelsService, queries, 30*time.Second) - + // Create LLM client for memory operations using chat provider var llmClient memory.LLM - memoryModel, memoryProvider, err := selectMemoryModel(ctx, modelsService, queries) + memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, modelsService, queries) if err != nil { log.Fatalf("select memory model: %v\nPlease configure at least one chat model in the database.", err) } - + log.Printf("Using memory model: %s (provider: %s)", memoryModel.ModelID, memoryProvider.ClientType) - provider, err := createChatProvider(memoryProvider, 30*time.Second) + provider, err := chat.CreateProvider(memoryProvider, 30*time.Second) if err != nil { log.Fatalf("create memory provider: %v", err) } llmClient = memory.NewProviderLLMClient(provider, memoryModel.ModelID) - + resolver := embeddings.NewResolver(modelsService, queries, 10*time.Second) - vectors, textModel, multimodalModel, hasModels, err := collectEmbeddingVectors(ctx, modelsService) + vectors, textModel, multimodalModel, hasModels, err := embeddings.CollectEmbeddingVectors(ctx, modelsService) if err != nil { log.Fatalf("embedding models: %v", err) } - + var memoryService *memory.Service var memoryHandler *handlers.MemoryHandler - + if !hasModels { log.Println("WARNING: No embedding models configured. Memory service will not be available.") log.Println("You can add embedding models via the /models API endpoint.") @@ -157,17 +107,17 @@ func main() { if multimodalModel.ModelID == "" { log.Println("WARNING: No multimodal embedding model configured. Multimodal embedding features will be limited.") } - + var textEmbedder embeddings.Embedder var store *memory.QdrantStore - + if textModel.ModelID != "" && textModel.Dimensions > 0 { - textEmbedder = &resolverTextEmbedder{ - resolver: resolver, - modelID: textModel.ModelID, - dims: textModel.Dimensions, + textEmbedder = &embeddings.ResolverTextEmbedder{ + Resolver: resolver, + ModelID: textModel.ModelID, + Dims: textModel.Dimensions, } - + if len(vectors) > 0 { store, err = memory.NewQdrantStoreWithVectors( cfg.Qdrant.BaseURL, @@ -192,103 +142,21 @@ func main() { } } } - + memoryService = memory.NewService(llmClient, textEmbedder, store, resolver, textModel.ModelID, multimodalModel.ModelID) memoryHandler = handlers.NewMemoryHandler(memoryService) } embeddingsHandler := handlers.NewEmbeddingsHandler(modelsService, queries) - fsHandler := handlers.NewFSHandler(service, manager, cfg.MCP, cfg.Containerd.Namespace) swaggerHandler := handlers.NewSwaggerHandler() chatHandler := handlers.NewChatHandler(chatResolver) - + // Initialize providers and models handlers providersService := providers.NewService(queries) providersHandler := handlers.NewProvidersHandler(providersService) modelsHandler := handlers.NewModelsHandler(modelsService) - - srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, fsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler) + srv := server.NewServer(addr, cfg.Auth.JWTSecret, pingHandler, authHandler, memoryHandler, embeddingsHandler, swaggerHandler, chatHandler, providersHandler, modelsHandler, containerdHandler) if err := srv.Start(); err != nil { log.Fatalf("server failed: %v", err) } } - -// selectMemoryModel selects a chat model for memory operations -func selectMemoryModel(ctx context.Context, modelsService *models.Service, queries *dbsqlc.Queries) (models.GetResponse, dbsqlc.LlmProvider, error) { - // First try to get the memory-enabled model - memoryModel, err := modelsService.GetByEnableAs(ctx, models.EnableAsMemory) - if err == nil { - provider, err := fetchProviderByID(ctx, queries, memoryModel.LlmProviderID) - if err != nil { - return models.GetResponse{}, dbsqlc.LlmProvider{}, err - } - return memoryModel, provider, nil - } - - // Fallback to chat model - chatModel, err := modelsService.GetByEnableAs(ctx, models.EnableAsChat) - if err == nil { - provider, err := fetchProviderByID(ctx, queries, chatModel.LlmProviderID) - if err != nil { - return models.GetResponse{}, dbsqlc.LlmProvider{}, err - } - return chatModel, provider, nil - } - - // If no enabled models, try to find any chat model - candidates, err := modelsService.ListByType(ctx, models.ModelTypeChat) - if err != nil || len(candidates) == 0 { - return models.GetResponse{}, dbsqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") - } - - selected := candidates[0] - provider, err := fetchProviderByID(ctx, queries, selected.LlmProviderID) - if err != nil { - return models.GetResponse{}, dbsqlc.LlmProvider{}, err - } - return selected, provider, nil -} - -// fetchProviderByID fetches a provider by ID -func fetchProviderByID(ctx context.Context, queries *dbsqlc.Queries, providerID string) (dbsqlc.LlmProvider, error) { - if strings.TrimSpace(providerID) == "" { - return dbsqlc.LlmProvider{}, fmt.Errorf("provider id missing") - } - parsed, err := uuid.Parse(providerID) - if err != nil { - return dbsqlc.LlmProvider{}, err - } - pgID := pgtype.UUID{Valid: true} - copy(pgID.Bytes[:], parsed[:]) - return queries.GetLlmProviderByID(ctx, pgID) -} - -// createChatProvider creates a chat provider instance -func createChatProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (chat.Provider, error) { - clientType := strings.ToLower(strings.TrimSpace(provider.ClientType)) - if timeout <= 0 { - timeout = 30 * time.Second - } - - switch clientType { - case chat.ProviderOpenAI, chat.ProviderOpenAICompat: - if strings.TrimSpace(provider.ApiKey) == "" { - return nil, fmt.Errorf("openai api key is required") - } - return chat.NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout) - case chat.ProviderAnthropic: - if strings.TrimSpace(provider.ApiKey) == "" { - return nil, fmt.Errorf("anthropic api key is required") - } - return chat.NewAnthropicProvider(provider.ApiKey, timeout) - case chat.ProviderGoogle: - if strings.TrimSpace(provider.ApiKey) == "" { - return nil, fmt.Errorf("google api key is required") - } - return chat.NewGoogleProvider(provider.ApiKey, timeout) - case chat.ProviderOllama: - return chat.NewOllamaProvider(provider.BaseUrl, timeout) - default: - return nil, fmt.Errorf("unsupported provider type: %s", clientType) - } -} diff --git a/cmd/mcp/Dockerfile b/cmd/mcp/Dockerfile new file mode 100644 index 00000000..fbaa82ee --- /dev/null +++ b/cmd/mcp/Dockerfile @@ -0,0 +1,15 @@ +FROM golang:1.25-alpine AS build + +WORKDIR /src +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +ARG TARGETARCH +ARG COMMIT_HASH=unknown +RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH:-amd64} \ + go build -trimpath -ldflags "-s -w -X main.commitHash=${COMMIT_HASH}" -o /out/mcp ./cmd/mcp + +FROM busybox:latest +COPY --from=build /out/mcp /mcp +ENTRYPOINT ["/mcp"] diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index c3580b15..2a8f3f20 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -2,165 +2,27 @@ package main import ( "context" - "flag" - "fmt" "log" - "os" - "time" - "github.com/memohai/memoh/internal/config" - ctr "github.com/memohai/memoh/internal/containerd" - "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/mcp" + gomcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + commitHash = "unknown" + version = "unknown" ) func main() { - if len(os.Args) < 2 { - usage() - return + if version == "unknown" { + version = "v0.0.0-dev+" + commitHash } - - ctx := context.Background() - cfgPath := os.Getenv("CONFIG_PATH") - cfg, err := config.Load(cfgPath) - if err != nil { - log.Fatalf("load config: %v", err) - } - - factory := ctr.DefaultClientFactory{SocketPath: cfg.Containerd.SocketPath} - client, err := factory.New(ctx) - if err != nil { - log.Fatalf("connect containerd: %v", err) - } - defer client.Close() - - service := ctr.NewDefaultService(client, cfg.Containerd.Namespace) - manager := mcp.NewManager(service, cfg.MCP) - - switch os.Args[1] { - case "init": - if err := manager.Init(ctx); err != nil { - log.Fatalf("init: %v", err) - } - case "list": - users, err := manager.ListUsers(ctx) - if err != nil { - log.Fatalf("list: %v", err) - } - for _, user := range users { - fmt.Println(user) - } - case "create": - userID := argAt(2) - if err := manager.EnsureUser(ctx, userID); err != nil { - log.Fatalf("create: %v", err) - } - case "start": - userID := argAt(2) - if err := manager.Start(ctx, userID); err != nil { - log.Fatalf("start: %v", err) - } - case "stop": - stopCmd(ctx, manager, os.Args[2:]) - case "delete": - userID := argAt(2) - if err := manager.Delete(ctx, userID); err != nil { - log.Fatalf("delete: %v", err) - } - case "exec": - withDB(ctx, cfg.Postgres, manager, func() { - execCmd(ctx, manager, os.Args[2:]) - }) - default: - usage() + server := gomcp.NewServer( + &gomcp.Implementation{Name: "memoh-mcp", Version: version}, + nil, + ) + mcp.RegisterTools(server) + if err := server.Run(context.Background(), &gomcp.StdioTransport{}); err != nil { + log.Fatal(err) } } - -func stopCmd(ctx context.Context, manager *mcp.Manager, args []string) { - fs := flag.NewFlagSet("stop", flag.ExitOnError) - timeout := fs.Duration("timeout", 10*time.Second, "stop timeout") - fs.Parse(args) - - userID := fs.Arg(0) - if userID == "" { - log.Fatalf("stop: user id required") - } - - if err := manager.Stop(ctx, userID, *timeout); err != nil { - log.Fatalf("stop: %v", err) - } -} - -func execCmd(ctx context.Context, manager *mcp.Manager, args []string) { - fs := flag.NewFlagSet("exec", flag.ExitOnError) - var envs stringSlice - cwd := fs.String("cwd", "", "working directory") - tty := fs.Bool("tty", false, "allocate a tty") - fs.Var(&envs, "env", "environment variable, can be repeated") - fs.Parse(args) - - userID := fs.Arg(0) - cmdArgs := fs.Args()[1:] - if userID == "" || len(cmdArgs) == 0 { - log.Fatalf("exec: user id and command required") - } - - result, err := manager.Exec(ctx, mcp.ExecRequest{ - UserID: userID, - Command: cmdArgs, - Env: envs, - WorkDir: *cwd, - Terminal: *tty, - UseStdio: true, - }) - if err != nil { - log.Fatalf("exec: %v", err) - } - if result.ExitCode != 0 { - os.Exit(int(result.ExitCode)) - } -} - -func argAt(index int) string { - if len(os.Args) <= index { - log.Fatalf("missing argument") - } - return os.Args[index] -} - -type stringSlice []string - -func (s *stringSlice) String() string { - return fmt.Sprintf("%v", []string(*s)) -} - -func (s *stringSlice) Set(value string) error { - *s = append(*s, value) - return nil -} - -func usage() { - fmt.Println("Usage: mcp [args]") - fmt.Println() - fmt.Println("Commands:") - fmt.Println(" init") - fmt.Println(" list") - fmt.Println(" create ") - fmt.Println(" start ") - fmt.Println(" stop [--timeout=10s]") - fmt.Println(" delete ") - fmt.Println(" exec [--cwd=DIR] [--tty] [--env=K=V] -- [args...]") - fmt.Println(" version-create ") - fmt.Println(" version-list ") - fmt.Println(" version-rollback ") -} - -func withDB(ctx context.Context, cfg config.PostgresConfig, manager *mcp.Manager, fn func()) { - conn, err := db.Open(ctx, cfg) - if err != nil { - log.Fatalf("db connect: %v", err) - } - defer conn.Close() - manager.WithDB(conn) - fn() -} diff --git a/config.toml.example b/config.toml.example index e38cfa7e..ac0f00a7 100644 --- a/config.toml.example +++ b/config.toml.example @@ -15,7 +15,7 @@ namespace = "default" [mcp] busybox_image = "docker.io/library/busybox:latest" -snapshotter = "" +snapshotter = "overlayfs" data_root = "data" data_mount = "/data" diff --git a/docs/docs.go b/docs/docs.go index f9b311a5..92be2126 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -199,61 +199,28 @@ const docTemplate = `{ } } }, - "/fs/apply_patch": { + "/mcp/containers": { "post": { - "description": "Apply a unified diff patch to a file under the user data mount", "tags": [ - "fs" + "containerd" ], - "summary": "Apply unified diff patch", + "summary": "Create and start MCP container", "parameters": [ { - "description": "Patch payload", + "description": "Create container payload", "name": "payload", "in": "body", "required": true, "schema": { - "$ref": "#/definitions/handlers.ApplyPatchRequest" + "$ref": "#/definitions/handlers.CreateContainerRequest" } } ], - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/fs/commit": { - "post": { - "description": "Create a new version snapshot for the user container", - "tags": [ - "fs" - ], - "summary": "Commit a filesystem snapshot", "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/handlers.CommitResponse" + "$ref": "#/definitions/handlers.CreateContainerResponse" } }, "400": { @@ -262,12 +229,6 @@ const docTemplate = `{ "$ref": "#/definitions/handlers.ErrorResponse" } }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, "500": { "description": "Internal Server Error", "schema": { @@ -277,34 +238,24 @@ const docTemplate = `{ } } }, - "/fs/diff": { - "get": { - "description": "Produce a unified diff between a version snapshot and current data", + "/mcp/containers/{id}": { + "delete": { "tags": [ - "fs" + "containerd" ], - "summary": "Diff against a version snapshot", + "summary": "Delete MCP container", "parameters": [ { "type": "string", - "description": "Path under data mount", - "name": "path", - "in": "query" - }, - { - "type": "integer", - "description": "Version number", - "name": "version", - "in": "query", + "description": "Container ID", + "name": "id", + "in": "path", "required": true } ], "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.DiffResponse" - } + "204": { + "description": "No Content" }, "400": { "description": "Bad Request", @@ -327,119 +278,29 @@ const docTemplate = `{ } } }, - "/fs/list": { - "get": { - "description": "List files under the user data mount", + "/mcp/snapshots": { + "post": { "tags": [ - "fs" + "containerd" ], - "summary": "List directory contents", + "summary": "Create container snapshot", "parameters": [ { - "type": "string", - "description": "Path under data mount", - "name": "path", - "in": "query" - }, - { - "type": "boolean", - "description": "Recursive listing", - "name": "recursive", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ListResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/fs/read": { - "get": { - "description": "Read a file under the user data mount", - "tags": [ - "fs" - ], - "summary": "Read file content", - "parameters": [ - { - "type": "string", - "description": "Path under data mount", - "name": "path", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ReadResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/fs/write_atomic": { - "put": { - "description": "Atomically replace a file under the user data mount", - "tags": [ - "fs" - ], - "summary": "Write file atomically", - "parameters": [ - { - "description": "Write payload", + "description": "Create snapshot payload", "name": "payload", "in": "body", "required": true, "schema": { - "$ref": "#/definitions/handlers.WriteAtomicRequest" + "$ref": "#/definitions/handlers.CreateSnapshotRequest" } } ], "responses": { - "204": { - "description": "No Content" + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.CreateSnapshotResponse" + } }, "400": { "description": "Bad Request", @@ -1657,45 +1518,56 @@ const docTemplate = `{ } } }, - "handlers.ApplyPatchRequest": { + "handlers.CreateContainerRequest": { "type": "object", "properties": { - "patch": { + "container_id": { "type": "string" }, - "path": { + "snapshotter": { "type": "string" } } }, - "handlers.CommitResponse": { + "handlers.CreateContainerResponse": { "type": "object", "properties": { - "created_at": { + "container_id": { "type": "string" }, - "id": { + "image": { "type": "string" }, - "snapshot_id": { + "snapshotter": { "type": "string" }, - "version": { - "type": "integer" + "started": { + "type": "boolean" } } }, - "handlers.DiffResponse": { + "handlers.CreateSnapshotRequest": { "type": "object", "properties": { - "diff": { + "container_id": { "type": "string" }, - "path": { + "snapshot_name": { + "type": "string" + } + } + }, + "handlers.CreateSnapshotResponse": { + "type": "object", + "properties": { + "container_id": { "type": "string" }, - "version": { - "type": "integer" + "snapshot_name": { + "type": "string" + }, + "snapshotter": { + "type": "string" } } }, @@ -1784,40 +1656,6 @@ const docTemplate = `{ } } }, - "handlers.FileEntry": { - "type": "object", - "properties": { - "is_dir": { - "type": "boolean" - }, - "mod_time": { - "type": "string" - }, - "mode": { - "type": "integer" - }, - "path": { - "type": "string" - }, - "size": { - "type": "integer" - } - } - }, - "handlers.ListResponse": { - "type": "object", - "properties": { - "entries": { - "type": "array", - "items": { - "$ref": "#/definitions/handlers.FileEntry" - } - }, - "path": { - "type": "string" - } - } - }, "handlers.LoginRequest": { "type": "object", "properties": { @@ -1849,49 +1687,6 @@ const docTemplate = `{ } } }, - "handlers.ReadResponse": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "encoding": { - "type": "string" - }, - "mod_time": { - "type": "string" - }, - "mode": { - "type": "integer" - }, - "path": { - "type": "string" - }, - "size": { - "type": "integer" - } - } - }, - "handlers.WriteAtomicRequest": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "encoding": { - "type": "string" - }, - "mode": { - "type": "integer" - }, - "mtime": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, "memory.AddRequest": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index 2192a103..42f97d50 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -188,61 +188,28 @@ } } }, - "/fs/apply_patch": { + "/mcp/containers": { "post": { - "description": "Apply a unified diff patch to a file under the user data mount", "tags": [ - "fs" + "containerd" ], - "summary": "Apply unified diff patch", + "summary": "Create and start MCP container", "parameters": [ { - "description": "Patch payload", + "description": "Create container payload", "name": "payload", "in": "body", "required": true, "schema": { - "$ref": "#/definitions/handlers.ApplyPatchRequest" + "$ref": "#/definitions/handlers.CreateContainerRequest" } } ], - "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/fs/commit": { - "post": { - "description": "Create a new version snapshot for the user container", - "tags": [ - "fs" - ], - "summary": "Commit a filesystem snapshot", "responses": { "200": { "description": "OK", "schema": { - "$ref": "#/definitions/handlers.CommitResponse" + "$ref": "#/definitions/handlers.CreateContainerResponse" } }, "400": { @@ -251,12 +218,6 @@ "$ref": "#/definitions/handlers.ErrorResponse" } }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, "500": { "description": "Internal Server Error", "schema": { @@ -266,34 +227,24 @@ } } }, - "/fs/diff": { - "get": { - "description": "Produce a unified diff between a version snapshot and current data", + "/mcp/containers/{id}": { + "delete": { "tags": [ - "fs" + "containerd" ], - "summary": "Diff against a version snapshot", + "summary": "Delete MCP container", "parameters": [ { "type": "string", - "description": "Path under data mount", - "name": "path", - "in": "query" - }, - { - "type": "integer", - "description": "Version number", - "name": "version", - "in": "query", + "description": "Container ID", + "name": "id", + "in": "path", "required": true } ], "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.DiffResponse" - } + "204": { + "description": "No Content" }, "400": { "description": "Bad Request", @@ -316,119 +267,29 @@ } } }, - "/fs/list": { - "get": { - "description": "List files under the user data mount", + "/mcp/snapshots": { + "post": { "tags": [ - "fs" + "containerd" ], - "summary": "List directory contents", + "summary": "Create container snapshot", "parameters": [ { - "type": "string", - "description": "Path under data mount", - "name": "path", - "in": "query" - }, - { - "type": "boolean", - "description": "Recursive listing", - "name": "recursive", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ListResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/fs/read": { - "get": { - "description": "Read a file under the user data mount", - "tags": [ - "fs" - ], - "summary": "Read file content", - "parameters": [ - { - "type": "string", - "description": "Path under data mount", - "name": "path", - "in": "query" - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/handlers.ReadResponse" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" - } - } - } - } - }, - "/fs/write_atomic": { - "put": { - "description": "Atomically replace a file under the user data mount", - "tags": [ - "fs" - ], - "summary": "Write file atomically", - "parameters": [ - { - "description": "Write payload", + "description": "Create snapshot payload", "name": "payload", "in": "body", "required": true, "schema": { - "$ref": "#/definitions/handlers.WriteAtomicRequest" + "$ref": "#/definitions/handlers.CreateSnapshotRequest" } } ], "responses": { - "204": { - "description": "No Content" + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.CreateSnapshotResponse" + } }, "400": { "description": "Bad Request", @@ -1646,45 +1507,56 @@ } } }, - "handlers.ApplyPatchRequest": { + "handlers.CreateContainerRequest": { "type": "object", "properties": { - "patch": { + "container_id": { "type": "string" }, - "path": { + "snapshotter": { "type": "string" } } }, - "handlers.CommitResponse": { + "handlers.CreateContainerResponse": { "type": "object", "properties": { - "created_at": { + "container_id": { "type": "string" }, - "id": { + "image": { "type": "string" }, - "snapshot_id": { + "snapshotter": { "type": "string" }, - "version": { - "type": "integer" + "started": { + "type": "boolean" } } }, - "handlers.DiffResponse": { + "handlers.CreateSnapshotRequest": { "type": "object", "properties": { - "diff": { + "container_id": { "type": "string" }, - "path": { + "snapshot_name": { + "type": "string" + } + } + }, + "handlers.CreateSnapshotResponse": { + "type": "object", + "properties": { + "container_id": { "type": "string" }, - "version": { - "type": "integer" + "snapshot_name": { + "type": "string" + }, + "snapshotter": { + "type": "string" } } }, @@ -1773,40 +1645,6 @@ } } }, - "handlers.FileEntry": { - "type": "object", - "properties": { - "is_dir": { - "type": "boolean" - }, - "mod_time": { - "type": "string" - }, - "mode": { - "type": "integer" - }, - "path": { - "type": "string" - }, - "size": { - "type": "integer" - } - } - }, - "handlers.ListResponse": { - "type": "object", - "properties": { - "entries": { - "type": "array", - "items": { - "$ref": "#/definitions/handlers.FileEntry" - } - }, - "path": { - "type": "string" - } - } - }, "handlers.LoginRequest": { "type": "object", "properties": { @@ -1838,49 +1676,6 @@ } } }, - "handlers.ReadResponse": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "encoding": { - "type": "string" - }, - "mod_time": { - "type": "string" - }, - "mode": { - "type": "integer" - }, - "path": { - "type": "string" - }, - "size": { - "type": "integer" - } - } - }, - "handlers.WriteAtomicRequest": { - "type": "object", - "properties": { - "content": { - "type": "string" - }, - "encoding": { - "type": "string" - }, - "mode": { - "type": "integer" - }, - "mtime": { - "type": "string" - }, - "path": { - "type": "string" - } - } - }, "memory.AddRequest": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 81243027..e58b55b4 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -58,32 +58,39 @@ definitions: total_tokens: type: integer type: object - handlers.ApplyPatchRequest: + handlers.CreateContainerRequest: properties: - patch: + container_id: type: string - path: + snapshotter: type: string type: object - handlers.CommitResponse: + handlers.CreateContainerResponse: properties: - created_at: + container_id: type: string - id: + image: type: string - snapshot_id: + snapshotter: type: string - version: - type: integer + started: + type: boolean type: object - handlers.DiffResponse: + handlers.CreateSnapshotRequest: properties: - diff: + container_id: type: string - path: + snapshot_name: + type: string + type: object + handlers.CreateSnapshotResponse: + properties: + container_id: + type: string + snapshot_name: + type: string + snapshotter: type: string - version: - type: integer type: object handlers.EmbeddingsInput: properties: @@ -140,28 +147,6 @@ definitions: message: type: string type: object - handlers.FileEntry: - properties: - is_dir: - type: boolean - mod_time: - type: string - mode: - type: integer - path: - type: string - size: - type: integer - type: object - handlers.ListResponse: - properties: - entries: - items: - $ref: '#/definitions/handlers.FileEntry' - type: array - path: - type: string - type: object handlers.LoginRequest: properties: password: @@ -182,34 +167,6 @@ definitions: username: type: string type: object - handlers.ReadResponse: - properties: - content: - type: string - encoding: - type: string - mod_time: - type: string - mode: - type: integer - path: - type: string - size: - type: integer - type: object - handlers.WriteAtomicRequest: - properties: - content: - type: string - encoding: - type: string - mode: - type: integer - mtime: - type: string - path: - type: string - type: object memory.AddRequest: properties: agent_id: @@ -638,16 +595,39 @@ paths: summary: Create embeddings tags: - embeddings - /fs/apply_patch: + /mcp/containers: post: - description: Apply a unified diff patch to a file under the user data mount parameters: - - description: Patch payload + - description: Create container payload in: body name: payload required: true schema: - $ref: '#/definitions/handlers.ApplyPatchRequest' + $ref: '#/definitions/handlers.CreateContainerRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.CreateContainerResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Create and start MCP container + tags: + - containerd + /mcp/containers/{id}: + delete: + parameters: + - description: Container ID + in: path + name: id + required: true + type: string responses: "204": description: No Content @@ -663,138 +643,23 @@ paths: description: Internal Server Error schema: $ref: '#/definitions/handlers.ErrorResponse' - summary: Apply unified diff patch + summary: Delete MCP container tags: - - fs - /fs/commit: + - containerd + /mcp/snapshots: post: - description: Create a new version snapshot for the user container - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.CommitResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "404": - description: Not Found - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Commit a filesystem snapshot - tags: - - fs - /fs/diff: - get: - description: Produce a unified diff between a version snapshot and current data parameters: - - description: Path under data mount - in: query - name: path - type: string - - description: Version number - in: query - name: version - required: true - type: integer - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.DiffResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "404": - description: Not Found - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Diff against a version snapshot - tags: - - fs - /fs/list: - get: - description: List files under the user data mount - parameters: - - description: Path under data mount - in: query - name: path - type: string - - description: Recursive listing - in: query - name: recursive - type: boolean - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.ListResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "404": - description: Not Found - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: List directory contents - tags: - - fs - /fs/read: - get: - description: Read a file under the user data mount - parameters: - - description: Path under data mount - in: query - name: path - type: string - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.ReadResponse' - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "404": - description: Not Found - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Read file content - tags: - - fs - /fs/write_atomic: - put: - description: Atomically replace a file under the user data mount - parameters: - - description: Write payload + - description: Create snapshot payload in: body name: payload required: true schema: - $ref: '#/definitions/handlers.WriteAtomicRequest' + $ref: '#/definitions/handlers.CreateSnapshotRequest' responses: - "204": - description: No Content + "200": + description: OK + schema: + $ref: '#/definitions/handlers.CreateSnapshotResponse' "400": description: Bad Request schema: @@ -807,9 +672,9 @@ paths: description: Internal Server Error schema: $ref: '#/definitions/handlers.ErrorResponse' - summary: Write file atomically + summary: Create container snapshot tags: - - fs + - containerd /memory/add: post: description: Add memory for a user via memory diff --git a/go.mod b/go.mod index 9bce88c1..65d518f9 100644 --- a/go.mod +++ b/go.mod @@ -7,14 +7,14 @@ require ( github.com/containerd/containerd/api v1.10.0 github.com/containerd/containerd/v2 v2.2.1 github.com/containerd/errdefs v1.0.0 - github.com/cyphar/filepath-securejoin v0.5.1 + github.com/firebase/genkit/go v1.4.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.8.0 github.com/labstack/echo-jwt/v4 v4.4.0 github.com/labstack/echo/v4 v4.15.0 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/opencontainers/runtime-spec v1.3.0 - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 github.com/qdrant/go-client v1.16.2 github.com/stretchr/testify v1.11.1 github.com/swaggo/swag v1.16.6 @@ -37,10 +37,10 @@ require ( github.com/containerd/plugin v1.0.0 // indirect github.com/containerd/ttrpc v1.2.7 // indirect github.com/containerd/typeurl/v2 v2.2.3 // indirect + github.com/cyphar/filepath-securejoin v0.5.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/firebase/genkit/go v1.4.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.22.4 // indirect @@ -58,6 +58,7 @@ require ( github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 // indirect github.com/google/go-cmp v0.7.0 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -78,6 +79,7 @@ require ( github.com/opencontainers/image-spec v1.1.1 // indirect github.com/opencontainers/selinux v1.13.1 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect @@ -96,6 +98,7 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/mod v0.32.0 // indirect golang.org/x/net v0.49.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect diff --git a/go.sum b/go.sum index ab237c90..2fcf6eac 100644 --- a/go.sum +++ b/go.sum @@ -70,7 +70,7 @@ github.com/go-openapi/jsonreference v0.21.4 h1:24qaE2y9bx/q3uRK/qN+TDwbok1NhbSmG github.com/go-openapi/jsonreference v0.21.4/go.mod h1:rIENPTjDbLpzQmQWCj5kKj3ZlmEh+EFVbz3RTUh30/4= github.com/go-openapi/spec v0.22.3 h1:qRSmj6Smz2rEBxMnLRBMeBWxbbOvuOoElvSvObIgwQc= github.com/go-openapi/spec v0.22.3/go.mod h1:iIImLODL2loCh3Vnox8TY2YWYJZjMAKYyLH2Mu8lOZs= -github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= +github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU= github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4= github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU= github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= @@ -123,6 +123,8 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -170,6 +172,8 @@ github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs= github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g= github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -237,6 +241,8 @@ go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2W go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -264,6 +270,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/internal/chat/factory.go b/internal/chat/factory.go new file mode 100644 index 00000000..ce80870f --- /dev/null +++ b/internal/chat/factory.go @@ -0,0 +1,39 @@ +package chat + +import ( + "fmt" + "strings" + "time" + + dbsqlc "github.com/memohai/memoh/internal/db/sqlc" +) + +// CreateProvider creates a chat provider instance. +func CreateProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (Provider, error) { + clientType := strings.ToLower(strings.TrimSpace(provider.ClientType)) + if timeout <= 0 { + timeout = 30 * time.Second + } + + switch clientType { + case ProviderOpenAI, ProviderOpenAICompat: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, fmt.Errorf("openai api key is required") + } + return NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout) + case ProviderAnthropic: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, fmt.Errorf("anthropic api key is required") + } + return NewAnthropicProvider(provider.ApiKey, timeout) + case ProviderGoogle: + if strings.TrimSpace(provider.ApiKey) == "" { + return nil, fmt.Errorf("google api key is required") + } + return NewGoogleProvider(provider.ApiKey, timeout) + case ProviderOllama: + return NewOllamaProvider(provider.BaseUrl, timeout) + default: + return nil, fmt.Errorf("unsupported provider type: %s", clientType) + } +} diff --git a/internal/containerd/service.go b/internal/containerd/service.go index 2f2c9c52..481beed2 100644 --- a/internal/containerd/service.go +++ b/internal/containerd/service.go @@ -6,6 +6,7 @@ import ( "fmt" "runtime" "syscall" + "strings" "time" tasksv1 "github.com/containerd/containerd/api/services/tasks/v1" @@ -16,6 +17,7 @@ import ( "github.com/containerd/containerd/v2/defaults" "github.com/containerd/containerd/v2/pkg/cio" "github.com/containerd/containerd/v2/pkg/namespaces" + "github.com/containerd/errdefs" "github.com/containerd/containerd/v2/pkg/oci" "github.com/opencontainers/runtime-spec/specs-go" ) @@ -179,13 +181,21 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine } ctx = s.withNamespace(ctx) - pullOpts := &PullImageOptions{ - Unpack: true, - Snapshotter: req.Snapshotter, - } - image, err := s.PullImage(ctx, req.ImageRef, pullOpts) + image, err := s.getImageWithFallback(ctx, req.ImageRef) if err != nil { - return nil, err + pullOpts := &PullImageOptions{ + Unpack: true, + Snapshotter: req.Snapshotter, + } + image, err = s.PullImage(ctx, req.ImageRef, pullOpts) + if err != nil { + return nil, err + } + } + if req.Snapshotter != "" { + if err := image.Unpack(ctx, req.Snapshotter); err != nil && !errdefs.IsAlreadyExists(err) { + return nil, err + } } snapshotID := req.SnapshotID @@ -224,6 +234,36 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine return s.client.NewContainer(ctx, req.ID, containerOpts...) } +func (s *DefaultService) getImageWithFallback(ctx context.Context, ref string) (containerd.Image, error) { + image, err := s.GetImage(ctx, ref) + if err == nil { + return image, nil + } + if strings.HasPrefix(ref, "docker.io/library/") { + alt := strings.TrimPrefix(ref, "docker.io/library/") + image, altErr := s.GetImage(ctx, alt) + if altErr == nil { + return image, nil + } + } + images, listErr := s.ListImages(ctx) + if listErr == nil { + for _, img := range images { + name := img.Name() + if name == ref || strings.HasSuffix(ref, "/"+name) || strings.HasSuffix(name, "/"+ref) { + return img, nil + } + if strings.HasPrefix(ref, "docker.io/library/") { + alt := strings.TrimPrefix(ref, "docker.io/library/") + if name == alt || strings.HasSuffix(name, "/"+alt) { + return img, nil + } + } + } + } + return nil, err +} + func (s *DefaultService) GetContainer(ctx context.Context, id string) (containerd.Container, error) { if id == "" { return nil, ErrInvalidArgument @@ -580,3 +620,4 @@ func (s *DefaultService) SnapshotMounts(ctx context.Context, snapshotter, key st func (s *DefaultService) withNamespace(ctx context.Context) context.Context { return namespaces.WithNamespace(ctx, s.namespace) } + diff --git a/internal/embeddings/bootstrap.go b/internal/embeddings/bootstrap.go new file mode 100644 index 00000000..a3e3b281 --- /dev/null +++ b/internal/embeddings/bootstrap.go @@ -0,0 +1,61 @@ +package embeddings + +import ( + "context" + + "github.com/memohai/memoh/internal/models" +) + +// ResolverTextEmbedder adapts Resolver to the Embedder interface for text embeddings. +type ResolverTextEmbedder struct { + Resolver *Resolver + ModelID string + Dims int +} + +func (e *ResolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) { + result, err := e.Resolver.Embed(ctx, Request{ + Type: TypeText, + Model: e.ModelID, + Input: Input{Text: input}, + }) + if err != nil { + return nil, err + } + return result.Embedding, nil +} + +func (e *ResolverTextEmbedder) Dimensions() int { + return e.Dims +} + +// CollectEmbeddingVectors gathers embedding model dimensions and defaults. +func CollectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) { + candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding) + if err != nil { + return nil, models.GetResponse{}, models.GetResponse{}, false, err + } + vectors := map[string]int{} + var textModel models.GetResponse + var multimodalModel models.GetResponse + for _, model := range candidates { + if model.Dimensions > 0 && model.ModelID != "" { + vectors[model.ModelID] = model.Dimensions + } + if model.IsMultimodal { + if multimodalModel.ModelID == "" { + multimodalModel = model + } + continue + } + if textModel.ModelID == "" { + textModel = model + } + } + + hasTextModel := textModel.ModelID != "" + hasMultimodalModel := multimodalModel.ModelID != "" + hasAnyModel := hasTextModel || hasMultimodalModel + + return vectors, textModel, multimodalModel, hasAnyModel, nil +} diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go new file mode 100644 index 00000000..5f623534 --- /dev/null +++ b/internal/handlers/containerd.go @@ -0,0 +1,188 @@ +package handlers + +import ( + "net/http" + "strings" + "time" + + "github.com/containerd/errdefs" + "github.com/containerd/containerd/v2/pkg/namespaces" + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/config" + ctr "github.com/memohai/memoh/internal/containerd" +) + +type ContainerdHandler struct { + service ctr.Service + cfg config.MCPConfig + namespace string +} + +type CreateContainerRequest struct { + ContainerID string `json:"container_id"` + Image string `json:"image,omitempty"` + Snapshotter string `json:"snapshotter,omitempty"` +} + +type CreateContainerResponse struct { + ContainerID string `json:"container_id"` + Image string `json:"image"` + Snapshotter string `json:"snapshotter"` + Started bool `json:"started"` +} + +type CreateSnapshotRequest struct { + ContainerID string `json:"container_id"` + SnapshotName string `json:"snapshot_name"` +} + +type CreateSnapshotResponse struct { + ContainerID string `json:"container_id"` + SnapshotName string `json:"snapshot_name"` + Snapshotter string `json:"snapshotter"` +} + +func NewContainerdHandler(service ctr.Service, cfg config.MCPConfig, namespace string) *ContainerdHandler { + return &ContainerdHandler{ + service: service, + cfg: cfg, + namespace: namespace, + } +} + +func (h *ContainerdHandler) Register(e *echo.Echo) { + group := e.Group("/mcp") + group.POST("/containers", h.CreateContainer) + group.DELETE("/containers/:id", h.DeleteContainer) + group.POST("/snapshots", h.CreateSnapshot) +} + +// CreateContainer godoc +// @Summary Create and start MCP container +// @Tags containerd +// @Param payload body CreateContainerRequest true "Create container payload" +// @Success 200 {object} CreateContainerResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /mcp/containers [post] +func (h *ContainerdHandler) CreateContainer(c echo.Context) error { + var req CreateContainerRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if strings.TrimSpace(req.ContainerID) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "container_id is required") + } + + image := strings.TrimSpace(req.Image) + if image == "" { + image = h.cfg.BusyboxImage + } + if image == "" { + image = config.DefaultBusyboxImg + } + snapshotter := strings.TrimSpace(req.Snapshotter) + if snapshotter == "" { + snapshotter = h.cfg.Snapshotter + } + if snapshotter == "" { + snapshotter = "overlayfs" + } + + _, err := h.service.CreateContainer(c.Request().Context(), ctr.CreateContainerRequest{ + ID: req.ContainerID, + ImageRef: image, + Snapshotter: snapshotter, + }) + if err != nil && !errdefs.IsAlreadyExists(err) { + return echo.NewHTTPError(http.StatusInternalServerError, "snapshotter="+snapshotter+" image="+image+" err="+err.Error()) + } + + started := false + if _, err := h.service.StartTask(c.Request().Context(), req.ContainerID, &ctr.StartTaskOptions{ + UseStdio: false, + }); err == nil { + started = true + } + + return c.JSON(http.StatusOK, CreateContainerResponse{ + ContainerID: req.ContainerID, + Image: image, + Snapshotter: snapshotter, + Started: started, + }) +} + +// DeleteContainer godoc +// @Summary Delete MCP container +// @Tags containerd +// @Param id path string true "Container ID" +// @Success 204 +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /mcp/containers/{id} [delete] +func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { + containerID := strings.TrimSpace(c.Param("id")) + if containerID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "container id is required") + } + _ = h.service.DeleteTask(c.Request().Context(), containerID, &ctr.DeleteTaskOptions{Force: true}) + if err := h.service.DeleteContainer(c.Request().Context(), containerID, &ctr.DeleteContainerOptions{ + CleanupSnapshot: true, + }); err != nil { + if errdefs.IsNotFound(err) { + return echo.NewHTTPError(http.StatusNotFound, "container not found") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// CreateSnapshot godoc +// @Summary Create container snapshot +// @Tags containerd +// @Param payload body CreateSnapshotRequest true "Create snapshot payload" +// @Success 200 {object} CreateSnapshotResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /mcp/snapshots [post] +func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { + var req CreateSnapshotRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if strings.TrimSpace(req.ContainerID) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "container_id is required") + } + container, err := h.service.GetContainer(c.Request().Context(), req.ContainerID) + if err != nil { + if errdefs.IsNotFound(err) { + return echo.NewHTTPError(http.StatusNotFound, "container not found") + } + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + ctx := c.Request().Context() + if strings.TrimSpace(h.namespace) != "" { + ctx = namespaces.WithNamespace(ctx, h.namespace) + } + info, err := container.Info(ctx) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + snapshotName := strings.TrimSpace(req.SnapshotName) + if snapshotName == "" { + snapshotName = req.ContainerID + "-" + time.Now().Format("20060102150405") + } + if err := h.service.CommitSnapshot(c.Request().Context(), info.Snapshotter, snapshotName, info.SnapshotKey); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + return c.JSON(http.StatusOK, CreateSnapshotResponse{ + ContainerID: req.ContainerID, + SnapshotName: snapshotName, + Snapshotter: info.Snapshotter, + }) +} diff --git a/internal/handlers/error.go b/internal/handlers/error.go new file mode 100644 index 00000000..d2db6e95 --- /dev/null +++ b/internal/handlers/error.go @@ -0,0 +1,5 @@ +package handlers + +type ErrorResponse struct { + Message string `json:"message"` +} diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go deleted file mode 100644 index 0a2d9999..00000000 --- a/internal/handlers/fs.go +++ /dev/null @@ -1,803 +0,0 @@ -package handlers - -import ( - "bytes" - "context" - "encoding/base64" - "fmt" - "io" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/containerd/containerd/v2/pkg/namespaces" - "github.com/containerd/errdefs" - securejoin "github.com/cyphar/filepath-securejoin" - "github.com/labstack/echo/v4" - "github.com/pmezard/go-difflib/difflib" - - "github.com/memohai/memoh/internal/auth" - "github.com/memohai/memoh/internal/config" - ctr "github.com/memohai/memoh/internal/containerd" - "github.com/memohai/memoh/internal/identity" - "github.com/memohai/memoh/internal/mcp" -) - -type FSHandler struct { - service ctr.Service - manager *mcp.Manager - mcpConfig config.MCPConfig - namespace string -} - -type ErrorResponse struct { - Message string `json:"message"` -} - -type ReadResponse struct { - Path string `json:"path"` - Content string `json:"content"` - Encoding string `json:"encoding"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type FileEntry struct { - Path string `json:"path"` - IsDir bool `json:"is_dir"` - Size int64 `json:"size"` - Mode uint32 `json:"mode"` - ModTime time.Time `json:"mod_time"` -} - -type ListResponse struct { - Path string `json:"path"` - Entries []FileEntry `json:"entries"` -} - -type WriteAtomicRequest struct { - Path string `json:"path"` - Content string `json:"content"` - Encoding string `json:"encoding"` - Mode *uint32 `json:"mode,omitempty"` - ModTime *time.Time `json:"mtime,omitempty"` -} - -type ApplyPatchRequest struct { - Path string `json:"path"` - Patch string `json:"patch"` -} - -type CommitResponse struct { - ID string `json:"id"` - Version int `json:"version"` - SnapshotID string `json:"snapshot_id"` - CreatedAt time.Time `json:"created_at"` -} - -type DiffResponse struct { - Path string `json:"path"` - Version int `json:"version"` - Diff string `json:"diff"` -} - -func NewFSHandler(service ctr.Service, manager *mcp.Manager, mcpConfig config.MCPConfig, namespace string) *FSHandler { - if namespace == "" { - namespace = config.DefaultNamespace - } - return &FSHandler{ - service: service, - manager: manager, - mcpConfig: mcpConfig, - namespace: namespace, - } -} - -func (h *FSHandler) Register(e *echo.Echo) { - group := e.Group("/fs") - group.GET("/read", h.Read) - group.GET("/list", h.List) - group.PUT("/write_atomic", h.WriteAtomic) - group.POST("/apply_patch", h.ApplyPatch) - group.POST("/commit", h.Commit) - group.GET("/diff", h.Diff) -} - -// Read godoc -// @Summary Read file content -// @Description Read a file under the user data mount -// @Tags fs -// @Param path query string false "Path under data mount" -// @Success 200 {object} ReadResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /fs/read [get] -func (h *FSHandler) Read(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - - ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace) - mount, err := h.mountUser(ctx, userID) - if err != nil { - return err - } - defer mount.Unmount() - - containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - hostPath, err := resolveHostPath(mount.Dir, containerPath) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - info, err := os.Stat(hostPath) - if err != nil { - if os.IsNotExist(err) { - return echo.NewHTTPError(http.StatusNotFound, "file not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if info.IsDir() { - return echo.NewHTTPError(http.StatusBadRequest, "path is a directory") - } - - data, err := os.ReadFile(hostPath) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - return c.JSON(http.StatusOK, ReadResponse{ - Path: containerPath, - Content: base64.StdEncoding.EncodeToString(data), - Encoding: "base64", - Size: info.Size(), - Mode: uint32(info.Mode().Perm()), - ModTime: info.ModTime(), - }) -} - -// List godoc -// @Summary List directory contents -// @Description List files under the user data mount -// @Tags fs -// @Param path query string false "Path under data mount" -// @Param recursive query bool false "Recursive listing" -// @Success 200 {object} ListResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /fs/list [get] -func (h *FSHandler) List(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - - ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace) - mount, err := h.mountUser(ctx, userID) - if err != nil { - return err - } - defer mount.Unmount() - - containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - hostPath, err := resolveHostPath(mount.Dir, containerPath) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - info, err := os.Stat(hostPath) - if err != nil { - if os.IsNotExist(err) { - return echo.NewHTTPError(http.StatusNotFound, "path not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if !info.IsDir() { - return echo.NewHTTPError(http.StatusBadRequest, "path is not a directory") - } - - recursive := strings.EqualFold(c.QueryParam("recursive"), "true") - entries := []FileEntry{} - if recursive { - err = filepath.WalkDir(hostPath, func(p string, d os.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr - } - if p == hostPath { - return nil - } - entryInfo, err := d.Info() - if err != nil { - return err - } - containerEntry, err := containerPathForHost(mount.Dir, p) - if err != nil { - return err - } - entries = append(entries, FileEntry{ - Path: containerEntry, - IsDir: d.IsDir(), - Size: entryInfo.Size(), - Mode: uint32(entryInfo.Mode().Perm()), - ModTime: entryInfo.ModTime(), - }) - return nil - }) - } else { - dirEntries, err := os.ReadDir(hostPath) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - for _, entry := range dirEntries { - entryInfo, err := entry.Info() - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - entryPath := filepath.Join(hostPath, entry.Name()) - containerEntry, err := containerPathForHost(mount.Dir, entryPath) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - entries = append(entries, FileEntry{ - Path: containerEntry, - IsDir: entry.IsDir(), - Size: entryInfo.Size(), - Mode: uint32(entryInfo.Mode().Perm()), - ModTime: entryInfo.ModTime(), - }) - } - } - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - return c.JSON(http.StatusOK, ListResponse{ - Path: containerPath, - Entries: entries, - }) -} - -// WriteAtomic godoc -// @Summary Write file atomically -// @Description Atomically replace a file under the user data mount -// @Tags fs -// @Param payload body WriteAtomicRequest true "Write payload" -// @Success 204 -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /fs/write_atomic [put] -func (h *FSHandler) WriteAtomic(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - - var req WriteAtomicRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if req.Path == "" { - return echo.NewHTTPError(http.StatusBadRequest, "path is required") - } - - ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace) - mount, err := h.mountUser(ctx, userID) - if err != nil { - return err - } - defer mount.Unmount() - - containerPath, err := resolveContainerPath(h.dataMount(), req.Path) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - hostPath, err := resolveHostPath(mount.Dir, containerPath) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - data, err := decodeContent(req.Content, req.Encoding) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - mode := os.FileMode(0o644) - if req.Mode != nil { - mode = os.FileMode(*req.Mode) - } - - if err := writeFileAtomic(hostPath, data, mode, req.ModTime); err != nil { - if os.IsNotExist(err) { - return echo.NewHTTPError(http.StatusNotFound, "path not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - return c.NoContent(http.StatusNoContent) -} - -// ApplyPatch godoc -// @Summary Apply unified diff patch -// @Description Apply a unified diff patch to a file under the user data mount -// @Tags fs -// @Param payload body ApplyPatchRequest true "Patch payload" -// @Success 204 -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /fs/apply_patch [post] -func (h *FSHandler) ApplyPatch(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - - var req ApplyPatchRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - if req.Path == "" || req.Patch == "" { - return echo.NewHTTPError(http.StatusBadRequest, "path and patch are required") - } - - ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace) - mount, err := h.mountUser(ctx, userID) - if err != nil { - return err - } - defer mount.Unmount() - - containerPath, err := resolveContainerPath(h.dataMount(), req.Path) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - hostPath, err := resolveHostPath(mount.Dir, containerPath) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - orig, err := os.ReadFile(hostPath) - if err != nil { - if os.IsNotExist(err) { - return echo.NewHTTPError(http.StatusNotFound, "file not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - updated, err := applyUnifiedPatch(string(orig), req.Patch) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - info, err := os.Stat(hostPath) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - if err := writeFileAtomic(hostPath, []byte(updated), info.Mode().Perm(), nil); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - return c.NoContent(http.StatusNoContent) -} - -// Commit godoc -// @Summary Commit a filesystem snapshot -// @Description Create a new version snapshot for the user container -// @Tags fs -// @Success 200 {object} CommitResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /fs/commit [post] -func (h *FSHandler) Commit(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - - if h.manager == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "manager not configured") - } - - ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace) - if err := h.ensureUserContainer(ctx, userID); err != nil { - return err - } - - info, err := h.manager.CreateVersion(ctx, userID) - if err != nil { - if errdefs.IsNotFound(err) { - return echo.NewHTTPError(http.StatusNotFound, "container not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - return c.JSON(http.StatusOK, CommitResponse{ - ID: info.ID, - Version: info.Version, - SnapshotID: info.SnapshotID, - CreatedAt: info.CreatedAt, - }) -} - -// Diff godoc -// @Summary Diff against a version snapshot -// @Description Produce a unified diff between a version snapshot and current data -// @Tags fs -// @Param path query string false "Path under data mount" -// @Param version query int true "Version number" -// @Success 200 {object} DiffResponse -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /fs/diff [get] -func (h *FSHandler) Diff(c echo.Context) error { - userID, err := h.requireUserID(c) - if err != nil { - return err - } - if h.manager == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "manager not configured") - } - - versionStr := c.QueryParam("version") - if versionStr == "" { - return echo.NewHTTPError(http.StatusBadRequest, "version is required") - } - version, err := strconv.Atoi(versionStr) - if err != nil || version <= 0 { - return echo.NewHTTPError(http.StatusBadRequest, "invalid version") - } - - containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace) - mount, err := h.mountUser(ctx, userID) - if err != nil { - return err - } - defer mount.Unmount() - - versionSnapshotID, err := h.manager.VersionSnapshotID(ctx, userID, version) - if err != nil { - if errdefs.IsNotFound(err) { - return echo.NewHTTPError(http.StatusNotFound, "version not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - versionDir, versionCleanup, err := ctr.MountSnapshot(ctx, h.service, mount.Info.Snapshotter, versionSnapshotID) - if err != nil { - if errdefs.IsNotFound(err) { - return echo.NewHTTPError(http.StatusNotFound, "snapshot not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - defer versionCleanup() - - currentHostPath, err := resolveHostPath(mount.Dir, containerPath) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - versionHostPath, err := resolveHostPath(versionDir, containerPath) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - currentContent, err := readFileOrEmpty(currentHostPath) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - versionContent, err := readFileOrEmpty(versionHostPath) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - diffText, err := unifiedDiff(containerPath, versionContent, currentContent) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - return c.JSON(http.StatusOK, DiffResponse{ - Path: containerPath, - Version: version, - Diff: diffText, - }) -} - -func (h *FSHandler) dataMount() string { - if h.mcpConfig.DataMount == "" { - return config.DefaultDataMount - } - return h.mcpConfig.DataMount -} - -func (h *FSHandler) requireUserID(c echo.Context) (string, error) { - userID, err := auth.UserIDFromContext(c) - if err != nil { - return "", err - } - if err := identity.ValidateUserID(userID); err != nil { - return "", echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - return userID, nil -} - -func (h *FSHandler) mountUser(ctx context.Context, userID string) (*ctr.MountedSnapshot, error) { - containerID := mcp.ContainerPrefix + userID - mount, err := ctr.MountContainerSnapshot(ctx, h.service, containerID) - if err != nil { - if errdefs.IsNotFound(err) { - return nil, echo.NewHTTPError(http.StatusNotFound, "container not found") - } - return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if label, ok := mount.Info.Labels[mcp.UserLabelKey]; !ok || label != userID { - _ = mount.Unmount() - return nil, echo.NewHTTPError(http.StatusForbidden, "user mismatch") - } - return mount, nil -} - -func (h *FSHandler) ensureUserContainer(ctx context.Context, userID string) error { - containerID := mcp.ContainerPrefix + userID - container, err := h.service.GetContainer(ctx, containerID) - if err != nil { - if errdefs.IsNotFound(err) { - return echo.NewHTTPError(http.StatusNotFound, "container not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - info, err := container.Info(ctx) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if label, ok := info.Labels[mcp.UserLabelKey]; !ok || label != userID { - return echo.NewHTTPError(http.StatusForbidden, "user mismatch") - } - return nil -} - -func resolveContainerPath(dataMount, requestPath string) (string, error) { - mountPath := path.Clean(dataMount) - if mountPath == "." || !strings.HasPrefix(mountPath, "/") { - return "", fmt.Errorf("data mount must be absolute") - } - - if requestPath == "" { - return mountPath, nil - } - - reqClean := path.Clean(requestPath) - if path.IsAbs(reqClean) { - if !pathWithin(reqClean, mountPath) { - return "", fmt.Errorf("path outside data mount") - } - return reqClean, nil - } - - return path.Join(mountPath, reqClean), nil -} - -func pathWithin(target, base string) bool { - if base == "/" { - return strings.HasPrefix(target, "/") - } - if target == base { - return true - } - if strings.HasPrefix(target, base) { - return len(target) > len(base) && target[len(base)] == '/' - } - return false -} - -func resolveHostPath(mountDir, containerPath string) (string, error) { - rel := strings.TrimPrefix(containerPath, "/") - return securejoin.SecureJoin(mountDir, rel) -} - -func containerPathForHost(mountDir, hostPath string) (string, error) { - rel, err := filepath.Rel(mountDir, hostPath) - if err != nil { - return "", err - } - if strings.HasPrefix(rel, "..") { - return "", fmt.Errorf("path escapes mount") - } - return "/" + filepath.ToSlash(rel), nil -} - -func decodeContent(content, encoding string) ([]byte, error) { - switch strings.ToLower(encoding) { - case "", "plain": - return []byte(content), nil - case "base64": - return base64.StdEncoding.DecodeString(content) - default: - return nil, fmt.Errorf("unsupported encoding") - } -} - -func writeFileAtomic(targetPath string, data []byte, mode os.FileMode, modTime *time.Time) error { - dir := filepath.Dir(targetPath) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err - } - - tmp, err := os.CreateTemp(dir, ".tmp-*") - if err != nil { - return err - } - tmpName := tmp.Name() - defer os.Remove(tmpName) - - if _, err := io.Copy(tmp, bytes.NewReader(data)); err != nil { - _ = tmp.Close() - return err - } - if err := tmp.Sync(); err != nil { - _ = tmp.Close() - return err - } - if err := tmp.Chmod(mode); err != nil { - _ = tmp.Close() - return err - } - if err := tmp.Close(); err != nil { - return err - } - if modTime != nil { - if err := os.Chtimes(tmpName, *modTime, *modTime); err != nil { - return err - } - } - if err := os.Rename(tmpName, targetPath); err != nil { - return err - } - if modTime != nil { - _ = os.Chtimes(targetPath, *modTime, *modTime) - } - return nil -} - -func applyUnifiedPatch(original, patch string) (string, error) { - lines := strings.Split(original, "\n") - out := make([]string, 0, len(lines)) - index := 0 - patchLines := strings.Split(patch, "\n") - hunksApplied := 0 - - for i := 0; i < len(patchLines); i++ { - line := patchLines[i] - if !strings.HasPrefix(line, "@@") { - continue - } - - origStart, err := parseUnifiedHunkHeader(line) - if err != nil { - return "", err - } - origStart-- - if origStart < 0 { - origStart = 0 - } - if origStart > len(lines) { - return "", fmt.Errorf("patch out of range") - } - - out = append(out, lines[index:origStart]...) - index = origStart - hunksApplied++ - - for i+1 < len(patchLines) { - next := patchLines[i+1] - if strings.HasPrefix(next, "@@") { - break - } - i++ - - if next == "" { - if i == len(patchLines)-1 { - break - } - return "", fmt.Errorf("invalid patch line") - } - if next[0] == '\\' { - continue - } - if len(next) < 1 { - return "", fmt.Errorf("invalid patch line") - } - op := next[0] - text := next[1:] - switch op { - case ' ': - if index >= len(lines) || lines[index] != text { - return "", fmt.Errorf("patch context mismatch") - } - out = append(out, text) - index++ - case '-': - if index >= len(lines) || lines[index] != text { - return "", fmt.Errorf("patch delete mismatch") - } - index++ - case '+': - out = append(out, text) - default: - return "", fmt.Errorf("invalid patch operation") - } - } - } - if hunksApplied == 0 { - return "", fmt.Errorf("patch contains no hunks") - } - - out = append(out, lines[index:]...) - return strings.Join(out, "\n"), nil -} - -func parseUnifiedHunkHeader(header string) (int, error) { - trimmed := strings.TrimPrefix(header, "@@") - trimmed = strings.TrimSpace(trimmed) - if !strings.HasPrefix(trimmed, "-") { - return 0, fmt.Errorf("invalid hunk header") - } - parts := strings.SplitN(trimmed, " ", 2) - if len(parts) < 2 { - return 0, fmt.Errorf("invalid hunk header") - } - - origPart := strings.TrimPrefix(parts[0], "-") - origFields := strings.SplitN(origPart, ",", 2) - origStart, err := strconv.Atoi(origFields[0]) - if err != nil { - return 0, fmt.Errorf("invalid hunk header") - } - return origStart, nil -} - -func readFileOrEmpty(path string) (string, error) { - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - return "", nil - } - return "", err - } - return string(data), nil -} - -func unifiedDiff(containerPath, oldContent, newContent string) (string, error) { - diffText, err := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ - A: strings.Split(oldContent, "\n"), - B: strings.Split(newContent, "\n"), - FromFile: "a" + containerPath, - ToFile: "b" + containerPath, - Context: 3, - }) - if err != nil { - return "", err - } - return diffText, nil -} diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go new file mode 100644 index 00000000..a7f85d9a --- /dev/null +++ b/internal/mcp/tools.go @@ -0,0 +1,408 @@ +package mcp + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type EchoInput struct { + Text string `json:"text" jsonschema:"text to echo"` +} + +type EchoOutput struct { + Text string `json:"text" jsonschema:"echoed text"` +} + +type FSReadInput struct { + Path string `json:"path" jsonschema:"relative file path"` +} + +type FSReadOutput struct { + Content string `json:"content" jsonschema:"file content"` +} + +type FSWriteInput struct { + Path string `json:"path" jsonschema:"relative file path"` + Content string `json:"content" jsonschema:"file content"` +} + +type FSWriteOutput struct { + OK bool `json:"ok" jsonschema:"write result"` +} + +type FSListInput struct { + Path string `json:"path" jsonschema:"relative directory path"` + Recursive bool `json:"recursive" jsonschema:"recursive listing"` +} + +type FSFileEntry struct { + Path string `json:"path" jsonschema:"relative entry path"` + IsDir bool `json:"is_dir" jsonschema:"is directory"` + Size int64 `json:"size" jsonschema:"entry size"` + Mode uint32 `json:"mode" jsonschema:"file mode"` + ModTime time.Time `json:"mod_time" jsonschema:"modification time"` +} + +type FSListOutput struct { + Path string `json:"path" jsonschema:"listed path"` + Entries []FSFileEntry `json:"entries" jsonschema:"entries"` +} + +type FSStatInput struct { + Path string `json:"path" jsonschema:"relative path"` +} + +type FSStatOutput struct { + Entry FSFileEntry `json:"entry" jsonschema:"entry"` +} + +type FSDeleteInput struct { + Path string `json:"path" jsonschema:"relative path"` +} + +type FSDeleteOutput struct { + OK bool `json:"ok" jsonschema:"delete result"` +} + +type FSApplyPatchInput struct { + Path string `json:"path" jsonschema:"relative file path"` + Patch string `json:"patch" jsonschema:"unified diff patch"` +} + +type FSApplyPatchOutput struct { + OK bool `json:"ok" jsonschema:"apply result"` +} + +func RegisterTools(server *sdkmcp.Server) { + sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "echo", Description: "echo input text"}, echoTool) + sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.read", Description: "read file content"}, fsReadTool) + sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.write", Description: "write file content"}, fsWriteTool) + sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.list", Description: "list directory entries"}, fsListTool) + sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.stat", Description: "stat file or directory"}, fsStatTool) + sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.delete", Description: "delete file or directory"}, fsDeleteTool) + sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.apply_patch", Description: "apply unified diff patch"}, fsApplyPatchTool) +} + +func echoTool(ctx context.Context, req *sdkmcp.CallToolRequest, input EchoInput) ( + *sdkmcp.CallToolResult, + EchoOutput, + error, +) { + return nil, EchoOutput{Text: input.Text}, nil +} + +func fsReadTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSReadInput) ( + *sdkmcp.CallToolResult, + FSReadOutput, + error, +) { + root := dataRoot() + target, err := resolvePath(root, input.Path) + if err != nil { + return nil, FSReadOutput{}, err + } + data, err := os.ReadFile(target) + if err != nil { + return nil, FSReadOutput{}, err + } + return nil, FSReadOutput{Content: string(data)}, nil +} + +func fsWriteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSWriteInput) ( + *sdkmcp.CallToolResult, + FSWriteOutput, + error, +) { + root := dataRoot() + target, err := resolvePath(root, input.Path) + if err != nil { + return nil, FSWriteOutput{}, err + } + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + return nil, FSWriteOutput{}, err + } + if err := os.WriteFile(target, []byte(input.Content), 0o644); err != nil { + return nil, FSWriteOutput{}, err + } + return nil, FSWriteOutput{OK: true}, nil +} + +func fsListTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSListInput) ( + *sdkmcp.CallToolResult, + FSListOutput, + error, +) { + root := dataRoot() + target, err := resolvePathAllowRoot(root, input.Path) + if err != nil { + return nil, FSListOutput{}, err + } + info, err := os.Stat(target) + if err != nil { + return nil, FSListOutput{}, err + } + if !info.IsDir() { + return nil, FSListOutput{}, fmt.Errorf("path is not a directory") + } + + entries := []FSFileEntry{} + if input.Recursive { + err = filepath.WalkDir(target, func(p string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if p == target { + return nil + } + entryInfo, err := d.Info() + if err != nil { + return err + } + entry, err := entryForPath(root, p, entryInfo) + if err != nil { + return err + } + entries = append(entries, entry) + return nil + }) + } else { + dirEntries, err := os.ReadDir(target) + if err != nil { + return nil, FSListOutput{}, err + } + for _, entry := range dirEntries { + entryInfo, err := entry.Info() + if err != nil { + return nil, FSListOutput{}, err + } + fullPath := filepath.Join(target, entry.Name()) + fileEntry, err := entryForPath(root, fullPath, entryInfo) + if err != nil { + return nil, FSListOutput{}, err + } + entries = append(entries, fileEntry) + } + } + if err != nil { + return nil, FSListOutput{}, err + } + + listedPath := strings.TrimSpace(input.Path) + if listedPath == "" { + listedPath = "." + } + return nil, FSListOutput{Path: listedPath, Entries: entries}, nil +} + +func fsStatTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSStatInput) ( + *sdkmcp.CallToolResult, + FSStatOutput, + error, +) { + root := dataRoot() + target, err := resolvePathAllowRoot(root, input.Path) + if err != nil { + return nil, FSStatOutput{}, err + } + info, err := os.Stat(target) + if err != nil { + return nil, FSStatOutput{}, err + } + entry, err := entryForPath(root, target, info) + if err != nil { + return nil, FSStatOutput{}, err + } + return nil, FSStatOutput{Entry: entry}, nil +} + +func fsDeleteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSDeleteInput) ( + *sdkmcp.CallToolResult, + FSDeleteOutput, + error, +) { + root := dataRoot() + target, err := resolvePath(root, input.Path) + if err != nil { + return nil, FSDeleteOutput{}, err + } + if err := os.RemoveAll(target); err != nil { + return nil, FSDeleteOutput{}, err + } + return nil, FSDeleteOutput{OK: true}, nil +} + +func fsApplyPatchTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSApplyPatchInput) ( + *sdkmcp.CallToolResult, + FSApplyPatchOutput, + error, +) { + root := dataRoot() + target, err := resolvePath(root, input.Path) + if err != nil { + return nil, FSApplyPatchOutput{}, err + } + orig, err := os.ReadFile(target) + if err != nil { + return nil, FSApplyPatchOutput{}, err + } + updated, err := applyUnifiedPatch(string(orig), input.Patch) + if err != nil { + return nil, FSApplyPatchOutput{}, err + } + info, err := os.Stat(target) + if err != nil { + return nil, FSApplyPatchOutput{}, err + } + if err := os.WriteFile(target, []byte(updated), info.Mode().Perm()); err != nil { + return nil, FSApplyPatchOutput{}, err + } + return nil, FSApplyPatchOutput{OK: true}, nil +} + +func dataRoot() string { + root := strings.TrimSpace(os.Getenv("MCP_DATA_DIR")) + if root == "" { + root = "/data" + } + return root +} + +func resolvePathAllowRoot(root, requestPath string) (string, error) { + if strings.TrimSpace(requestPath) == "" { + return root, nil + } + return resolvePath(root, requestPath) +} + +func resolvePath(root, requestPath string) (string, error) { + clean := filepath.Clean(requestPath) + if clean == "." || clean == "" { + return "", os.ErrInvalid + } + if filepath.IsAbs(clean) || strings.HasPrefix(clean, "..") { + return "", os.ErrInvalid + } + return filepath.Join(root, clean), nil +} + +func entryForPath(root, target string, info os.FileInfo) (FSFileEntry, error) { + rel, err := filepath.Rel(root, target) + if err != nil { + return FSFileEntry{}, err + } + if strings.HasPrefix(rel, "..") { + return FSFileEntry{}, os.ErrInvalid + } + if rel == "." { + rel = "" + } + return FSFileEntry{ + Path: filepath.ToSlash(rel), + IsDir: info.IsDir(), + Size: info.Size(), + Mode: uint32(info.Mode().Perm()), + ModTime: info.ModTime(), + }, nil +} + +func applyUnifiedPatch(original, patch string) (string, error) { + lines := strings.Split(original, "\n") + out := make([]string, 0, len(lines)) + index := 0 + patchLines := strings.Split(patch, "\n") + hunksApplied := 0 + + for i := 0; i < len(patchLines); i++ { + line := patchLines[i] + if !strings.HasPrefix(line, "@@") { + continue + } + + origStart, err := parseUnifiedHunkHeader(line) + if err != nil { + return "", err + } + origStart-- + if origStart < 0 { + origStart = 0 + } + if origStart > len(lines) { + return "", fmt.Errorf("patch out of range") + } + + out = append(out, lines[index:origStart]...) + index = origStart + hunksApplied++ + + for i+1 < len(patchLines) { + next := patchLines[i+1] + if strings.HasPrefix(next, "@@") { + break + } + i++ + + if next == "" { + if i == len(patchLines)-1 { + break + } + return "", fmt.Errorf("invalid patch line") + } + if next[0] == '\\' { + continue + } + op := next[0] + text := next[1:] + switch op { + case ' ': + if index >= len(lines) || lines[index] != text { + return "", fmt.Errorf("patch context mismatch") + } + out = append(out, text) + index++ + case '-': + if index >= len(lines) || lines[index] != text { + return "", fmt.Errorf("patch delete mismatch") + } + index++ + case '+': + out = append(out, text) + default: + return "", fmt.Errorf("invalid patch operation") + } + } + } + if hunksApplied == 0 { + return "", fmt.Errorf("patch contains no hunks") + } + + out = append(out, lines[index:]...) + return strings.Join(out, "\n"), nil +} + +func parseUnifiedHunkHeader(header string) (int, error) { + trimmed := strings.TrimPrefix(header, "@@") + trimmed = strings.TrimSpace(trimmed) + if !strings.HasPrefix(trimmed, "-") { + return 0, fmt.Errorf("invalid hunk header") + } + parts := strings.SplitN(trimmed, " ", 2) + if len(parts) < 2 { + return 0, fmt.Errorf("invalid hunk header") + } + + origPart := strings.TrimPrefix(parts[0], "-") + origFields := strings.SplitN(origPart, ",", 2) + origStart, err := strconv.Atoi(origFields[0]) + if err != nil { + return 0, fmt.Errorf("invalid hunk header") + } + return origStart, nil +} diff --git a/internal/models/bootstrap.go b/internal/models/bootstrap.go new file mode 100644 index 00000000..4c302f7a --- /dev/null +++ b/internal/models/bootstrap.go @@ -0,0 +1,57 @@ +package models + +import ( + "context" + "fmt" + "strings" + + "github.com/memohai/memoh/internal/db/sqlc" +) + +// SelectMemoryModel selects a chat model for memory operations. +func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) { + // First try to get the memory-enabled model. + memoryModel, err := modelsService.GetByEnableAs(ctx, EnableAsMemory) + if err == nil { + provider, err := FetchProviderByID(ctx, queries, memoryModel.LlmProviderID) + if err != nil { + return GetResponse{}, sqlc.LlmProvider{}, err + } + return memoryModel, provider, nil + } + + // Fallback to chat model. + chatModel, err := modelsService.GetByEnableAs(ctx, EnableAsChat) + if err == nil { + provider, err := FetchProviderByID(ctx, queries, chatModel.LlmProviderID) + if err != nil { + return GetResponse{}, sqlc.LlmProvider{}, err + } + return chatModel, provider, nil + } + + // If no enabled models, try to find any chat model. + candidates, err := modelsService.ListByType(ctx, ModelTypeChat) + if err != nil || len(candidates) == 0 { + return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") + } + + selected := candidates[0] + provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID) + if err != nil { + return GetResponse{}, sqlc.LlmProvider{}, err + } + return selected, provider, nil +} + +// FetchProviderByID fetches a provider by ID. +func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) { + if strings.TrimSpace(providerID) == "" { + return sqlc.LlmProvider{}, fmt.Errorf("provider id missing") + } + parsed, err := parseUUID(providerID) + if err != nil { + return sqlc.LlmProvider{}, err + } + return queries.GetLlmProviderByID(ctx, parsed) +} diff --git a/internal/server/server.go b/internal/server/server.go index 2b1b68e7..bfcff7a4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,7 @@ type Server struct { addr string } -func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, fsHandler *handlers.FSHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler) *Server { +func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, containerdHandler *handlers.ContainerdHandler) *Server { if addr == "" { addr = ":8080" } @@ -29,6 +29,9 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, if path == "/ping" || path == "/api/swagger.json" || path == "/auth/login" { return true } + if strings.HasPrefix(path, "/mcp/") { + return true + } if strings.HasPrefix(path, "/api/docs") { return true } @@ -47,9 +50,6 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, if embeddingsHandler != nil { embeddingsHandler.Register(e) } - if fsHandler != nil { - fsHandler.Register(e) - } if swaggerHandler != nil { swaggerHandler.Register(e) } @@ -62,6 +62,9 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, if modelsHandler != nil { modelsHandler.Register(e) } + if containerdHandler != nil { + containerdHandler.Register(e) + } return &Server{ echo: e,