refact: go mcp tool in containerd

This commit is contained in:
Ran
2026-01-28 04:48:32 +07:00
parent da6a264699
commit bb5482b982
18 changed files with 1046 additions and 1836 deletions
+39
View File
@@ -0,0 +1,39 @@
package chat
import (
"fmt"
"strings"
"time"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
)
// CreateProvider creates a chat provider instance.
func CreateProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (Provider, error) {
clientType := strings.ToLower(strings.TrimSpace(provider.ClientType))
if timeout <= 0 {
timeout = 30 * time.Second
}
switch clientType {
case ProviderOpenAI, ProviderOpenAICompat:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, fmt.Errorf("openai api key is required")
}
return NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout)
case ProviderAnthropic:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, fmt.Errorf("anthropic api key is required")
}
return NewAnthropicProvider(provider.ApiKey, timeout)
case ProviderGoogle:
if strings.TrimSpace(provider.ApiKey) == "" {
return nil, fmt.Errorf("google api key is required")
}
return NewGoogleProvider(provider.ApiKey, timeout)
case ProviderOllama:
return NewOllamaProvider(provider.BaseUrl, timeout)
default:
return nil, fmt.Errorf("unsupported provider type: %s", clientType)
}
}
+47 -6
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"runtime"
"syscall"
"strings"
"time"
tasksv1 "github.com/containerd/containerd/api/services/tasks/v1"
@@ -16,6 +17,7 @@ import (
"github.com/containerd/containerd/v2/defaults"
"github.com/containerd/containerd/v2/pkg/cio"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/errdefs"
"github.com/containerd/containerd/v2/pkg/oci"
"github.com/opencontainers/runtime-spec/specs-go"
)
@@ -179,13 +181,21 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine
}
ctx = s.withNamespace(ctx)
pullOpts := &PullImageOptions{
Unpack: true,
Snapshotter: req.Snapshotter,
}
image, err := s.PullImage(ctx, req.ImageRef, pullOpts)
image, err := s.getImageWithFallback(ctx, req.ImageRef)
if err != nil {
return nil, err
pullOpts := &PullImageOptions{
Unpack: true,
Snapshotter: req.Snapshotter,
}
image, err = s.PullImage(ctx, req.ImageRef, pullOpts)
if err != nil {
return nil, err
}
}
if req.Snapshotter != "" {
if err := image.Unpack(ctx, req.Snapshotter); err != nil && !errdefs.IsAlreadyExists(err) {
return nil, err
}
}
snapshotID := req.SnapshotID
@@ -224,6 +234,36 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine
return s.client.NewContainer(ctx, req.ID, containerOpts...)
}
func (s *DefaultService) getImageWithFallback(ctx context.Context, ref string) (containerd.Image, error) {
image, err := s.GetImage(ctx, ref)
if err == nil {
return image, nil
}
if strings.HasPrefix(ref, "docker.io/library/") {
alt := strings.TrimPrefix(ref, "docker.io/library/")
image, altErr := s.GetImage(ctx, alt)
if altErr == nil {
return image, nil
}
}
images, listErr := s.ListImages(ctx)
if listErr == nil {
for _, img := range images {
name := img.Name()
if name == ref || strings.HasSuffix(ref, "/"+name) || strings.HasSuffix(name, "/"+ref) {
return img, nil
}
if strings.HasPrefix(ref, "docker.io/library/") {
alt := strings.TrimPrefix(ref, "docker.io/library/")
if name == alt || strings.HasSuffix(name, "/"+alt) {
return img, nil
}
}
}
}
return nil, err
}
func (s *DefaultService) GetContainer(ctx context.Context, id string) (containerd.Container, error) {
if id == "" {
return nil, ErrInvalidArgument
@@ -580,3 +620,4 @@ func (s *DefaultService) SnapshotMounts(ctx context.Context, snapshotter, key st
func (s *DefaultService) withNamespace(ctx context.Context) context.Context {
return namespaces.WithNamespace(ctx, s.namespace)
}
+61
View File
@@ -0,0 +1,61 @@
package embeddings
import (
"context"
"github.com/memohai/memoh/internal/models"
)
// ResolverTextEmbedder adapts Resolver to the Embedder interface for text embeddings.
type ResolverTextEmbedder struct {
Resolver *Resolver
ModelID string
Dims int
}
func (e *ResolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) {
result, err := e.Resolver.Embed(ctx, Request{
Type: TypeText,
Model: e.ModelID,
Input: Input{Text: input},
})
if err != nil {
return nil, err
}
return result.Embedding, nil
}
func (e *ResolverTextEmbedder) Dimensions() int {
return e.Dims
}
// CollectEmbeddingVectors gathers embedding model dimensions and defaults.
func CollectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) {
candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding)
if err != nil {
return nil, models.GetResponse{}, models.GetResponse{}, false, err
}
vectors := map[string]int{}
var textModel models.GetResponse
var multimodalModel models.GetResponse
for _, model := range candidates {
if model.Dimensions > 0 && model.ModelID != "" {
vectors[model.ModelID] = model.Dimensions
}
if model.IsMultimodal {
if multimodalModel.ModelID == "" {
multimodalModel = model
}
continue
}
if textModel.ModelID == "" {
textModel = model
}
}
hasTextModel := textModel.ModelID != ""
hasMultimodalModel := multimodalModel.ModelID != ""
hasAnyModel := hasTextModel || hasMultimodalModel
return vectors, textModel, multimodalModel, hasAnyModel, nil
}
+188
View File
@@ -0,0 +1,188 @@
package handlers
import (
"net/http"
"strings"
"time"
"github.com/containerd/errdefs"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/config"
ctr "github.com/memohai/memoh/internal/containerd"
)
type ContainerdHandler struct {
service ctr.Service
cfg config.MCPConfig
namespace string
}
type CreateContainerRequest struct {
ContainerID string `json:"container_id"`
Image string `json:"image,omitempty"`
Snapshotter string `json:"snapshotter,omitempty"`
}
type CreateContainerResponse struct {
ContainerID string `json:"container_id"`
Image string `json:"image"`
Snapshotter string `json:"snapshotter"`
Started bool `json:"started"`
}
type CreateSnapshotRequest struct {
ContainerID string `json:"container_id"`
SnapshotName string `json:"snapshot_name"`
}
type CreateSnapshotResponse struct {
ContainerID string `json:"container_id"`
SnapshotName string `json:"snapshot_name"`
Snapshotter string `json:"snapshotter"`
}
func NewContainerdHandler(service ctr.Service, cfg config.MCPConfig, namespace string) *ContainerdHandler {
return &ContainerdHandler{
service: service,
cfg: cfg,
namespace: namespace,
}
}
func (h *ContainerdHandler) Register(e *echo.Echo) {
group := e.Group("/mcp")
group.POST("/containers", h.CreateContainer)
group.DELETE("/containers/:id", h.DeleteContainer)
group.POST("/snapshots", h.CreateSnapshot)
}
// CreateContainer godoc
// @Summary Create and start MCP container
// @Tags containerd
// @Param payload body CreateContainerRequest true "Create container payload"
// @Success 200 {object} CreateContainerResponse
// @Failure 400 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /mcp/containers [post]
func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
var req CreateContainerRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if strings.TrimSpace(req.ContainerID) == "" {
return echo.NewHTTPError(http.StatusBadRequest, "container_id is required")
}
image := strings.TrimSpace(req.Image)
if image == "" {
image = h.cfg.BusyboxImage
}
if image == "" {
image = config.DefaultBusyboxImg
}
snapshotter := strings.TrimSpace(req.Snapshotter)
if snapshotter == "" {
snapshotter = h.cfg.Snapshotter
}
if snapshotter == "" {
snapshotter = "overlayfs"
}
_, err := h.service.CreateContainer(c.Request().Context(), ctr.CreateContainerRequest{
ID: req.ContainerID,
ImageRef: image,
Snapshotter: snapshotter,
})
if err != nil && !errdefs.IsAlreadyExists(err) {
return echo.NewHTTPError(http.StatusInternalServerError, "snapshotter="+snapshotter+" image="+image+" err="+err.Error())
}
started := false
if _, err := h.service.StartTask(c.Request().Context(), req.ContainerID, &ctr.StartTaskOptions{
UseStdio: false,
}); err == nil {
started = true
}
return c.JSON(http.StatusOK, CreateContainerResponse{
ContainerID: req.ContainerID,
Image: image,
Snapshotter: snapshotter,
Started: started,
})
}
// DeleteContainer godoc
// @Summary Delete MCP container
// @Tags containerd
// @Param id path string true "Container ID"
// @Success 204
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /mcp/containers/{id} [delete]
func (h *ContainerdHandler) DeleteContainer(c echo.Context) error {
containerID := strings.TrimSpace(c.Param("id"))
if containerID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "container id is required")
}
_ = h.service.DeleteTask(c.Request().Context(), containerID, &ctr.DeleteTaskOptions{Force: true})
if err := h.service.DeleteContainer(c.Request().Context(), containerID, &ctr.DeleteContainerOptions{
CleanupSnapshot: true,
}); err != nil {
if errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusNotFound, "container not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusNoContent)
}
// CreateSnapshot godoc
// @Summary Create container snapshot
// @Tags containerd
// @Param payload body CreateSnapshotRequest true "Create snapshot payload"
// @Success 200 {object} CreateSnapshotResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /mcp/snapshots [post]
func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error {
var req CreateSnapshotRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if strings.TrimSpace(req.ContainerID) == "" {
return echo.NewHTTPError(http.StatusBadRequest, "container_id is required")
}
container, err := h.service.GetContainer(c.Request().Context(), req.ContainerID)
if err != nil {
if errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusNotFound, "container not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
ctx := c.Request().Context()
if strings.TrimSpace(h.namespace) != "" {
ctx = namespaces.WithNamespace(ctx, h.namespace)
}
info, err := container.Info(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
snapshotName := strings.TrimSpace(req.SnapshotName)
if snapshotName == "" {
snapshotName = req.ContainerID + "-" + time.Now().Format("20060102150405")
}
if err := h.service.CommitSnapshot(c.Request().Context(), info.Snapshotter, snapshotName, info.SnapshotKey); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, CreateSnapshotResponse{
ContainerID: req.ContainerID,
SnapshotName: snapshotName,
Snapshotter: info.Snapshotter,
})
}
+5
View File
@@ -0,0 +1,5 @@
package handlers
type ErrorResponse struct {
Message string `json:"message"`
}
-803
View File
@@ -1,803 +0,0 @@
package handlers
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/errdefs"
securejoin "github.com/cyphar/filepath-securejoin"
"github.com/labstack/echo/v4"
"github.com/pmezard/go-difflib/difflib"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/config"
ctr "github.com/memohai/memoh/internal/containerd"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/mcp"
)
type FSHandler struct {
service ctr.Service
manager *mcp.Manager
mcpConfig config.MCPConfig
namespace string
}
type ErrorResponse struct {
Message string `json:"message"`
}
type ReadResponse struct {
Path string `json:"path"`
Content string `json:"content"`
Encoding string `json:"encoding"`
Size int64 `json:"size"`
Mode uint32 `json:"mode"`
ModTime time.Time `json:"mod_time"`
}
type FileEntry struct {
Path string `json:"path"`
IsDir bool `json:"is_dir"`
Size int64 `json:"size"`
Mode uint32 `json:"mode"`
ModTime time.Time `json:"mod_time"`
}
type ListResponse struct {
Path string `json:"path"`
Entries []FileEntry `json:"entries"`
}
type WriteAtomicRequest struct {
Path string `json:"path"`
Content string `json:"content"`
Encoding string `json:"encoding"`
Mode *uint32 `json:"mode,omitempty"`
ModTime *time.Time `json:"mtime,omitempty"`
}
type ApplyPatchRequest struct {
Path string `json:"path"`
Patch string `json:"patch"`
}
type CommitResponse struct {
ID string `json:"id"`
Version int `json:"version"`
SnapshotID string `json:"snapshot_id"`
CreatedAt time.Time `json:"created_at"`
}
type DiffResponse struct {
Path string `json:"path"`
Version int `json:"version"`
Diff string `json:"diff"`
}
func NewFSHandler(service ctr.Service, manager *mcp.Manager, mcpConfig config.MCPConfig, namespace string) *FSHandler {
if namespace == "" {
namespace = config.DefaultNamespace
}
return &FSHandler{
service: service,
manager: manager,
mcpConfig: mcpConfig,
namespace: namespace,
}
}
func (h *FSHandler) Register(e *echo.Echo) {
group := e.Group("/fs")
group.GET("/read", h.Read)
group.GET("/list", h.List)
group.PUT("/write_atomic", h.WriteAtomic)
group.POST("/apply_patch", h.ApplyPatch)
group.POST("/commit", h.Commit)
group.GET("/diff", h.Diff)
}
// Read godoc
// @Summary Read file content
// @Description Read a file under the user data mount
// @Tags fs
// @Param path query string false "Path under data mount"
// @Success 200 {object} ReadResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /fs/read [get]
func (h *FSHandler) Read(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
mount, err := h.mountUser(ctx, userID)
if err != nil {
return err
}
defer mount.Unmount()
containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
hostPath, err := resolveHostPath(mount.Dir, containerPath)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
info, err := os.Stat(hostPath)
if err != nil {
if os.IsNotExist(err) {
return echo.NewHTTPError(http.StatusNotFound, "file not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if info.IsDir() {
return echo.NewHTTPError(http.StatusBadRequest, "path is a directory")
}
data, err := os.ReadFile(hostPath)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, ReadResponse{
Path: containerPath,
Content: base64.StdEncoding.EncodeToString(data),
Encoding: "base64",
Size: info.Size(),
Mode: uint32(info.Mode().Perm()),
ModTime: info.ModTime(),
})
}
// List godoc
// @Summary List directory contents
// @Description List files under the user data mount
// @Tags fs
// @Param path query string false "Path under data mount"
// @Param recursive query bool false "Recursive listing"
// @Success 200 {object} ListResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /fs/list [get]
func (h *FSHandler) List(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
mount, err := h.mountUser(ctx, userID)
if err != nil {
return err
}
defer mount.Unmount()
containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
hostPath, err := resolveHostPath(mount.Dir, containerPath)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
info, err := os.Stat(hostPath)
if err != nil {
if os.IsNotExist(err) {
return echo.NewHTTPError(http.StatusNotFound, "path not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if !info.IsDir() {
return echo.NewHTTPError(http.StatusBadRequest, "path is not a directory")
}
recursive := strings.EqualFold(c.QueryParam("recursive"), "true")
entries := []FileEntry{}
if recursive {
err = filepath.WalkDir(hostPath, func(p string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if p == hostPath {
return nil
}
entryInfo, err := d.Info()
if err != nil {
return err
}
containerEntry, err := containerPathForHost(mount.Dir, p)
if err != nil {
return err
}
entries = append(entries, FileEntry{
Path: containerEntry,
IsDir: d.IsDir(),
Size: entryInfo.Size(),
Mode: uint32(entryInfo.Mode().Perm()),
ModTime: entryInfo.ModTime(),
})
return nil
})
} else {
dirEntries, err := os.ReadDir(hostPath)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
for _, entry := range dirEntries {
entryInfo, err := entry.Info()
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
entryPath := filepath.Join(hostPath, entry.Name())
containerEntry, err := containerPathForHost(mount.Dir, entryPath)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
entries = append(entries, FileEntry{
Path: containerEntry,
IsDir: entry.IsDir(),
Size: entryInfo.Size(),
Mode: uint32(entryInfo.Mode().Perm()),
ModTime: entryInfo.ModTime(),
})
}
}
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, ListResponse{
Path: containerPath,
Entries: entries,
})
}
// WriteAtomic godoc
// @Summary Write file atomically
// @Description Atomically replace a file under the user data mount
// @Tags fs
// @Param payload body WriteAtomicRequest true "Write payload"
// @Success 204
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /fs/write_atomic [put]
func (h *FSHandler) WriteAtomic(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
var req WriteAtomicRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if req.Path == "" {
return echo.NewHTTPError(http.StatusBadRequest, "path is required")
}
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
mount, err := h.mountUser(ctx, userID)
if err != nil {
return err
}
defer mount.Unmount()
containerPath, err := resolveContainerPath(h.dataMount(), req.Path)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
hostPath, err := resolveHostPath(mount.Dir, containerPath)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
data, err := decodeContent(req.Content, req.Encoding)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
mode := os.FileMode(0o644)
if req.Mode != nil {
mode = os.FileMode(*req.Mode)
}
if err := writeFileAtomic(hostPath, data, mode, req.ModTime); err != nil {
if os.IsNotExist(err) {
return echo.NewHTTPError(http.StatusNotFound, "path not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusNoContent)
}
// ApplyPatch godoc
// @Summary Apply unified diff patch
// @Description Apply a unified diff patch to a file under the user data mount
// @Tags fs
// @Param payload body ApplyPatchRequest true "Patch payload"
// @Success 204
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /fs/apply_patch [post]
func (h *FSHandler) ApplyPatch(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
var req ApplyPatchRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if req.Path == "" || req.Patch == "" {
return echo.NewHTTPError(http.StatusBadRequest, "path and patch are required")
}
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
mount, err := h.mountUser(ctx, userID)
if err != nil {
return err
}
defer mount.Unmount()
containerPath, err := resolveContainerPath(h.dataMount(), req.Path)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
hostPath, err := resolveHostPath(mount.Dir, containerPath)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
orig, err := os.ReadFile(hostPath)
if err != nil {
if os.IsNotExist(err) {
return echo.NewHTTPError(http.StatusNotFound, "file not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
updated, err := applyUnifiedPatch(string(orig), req.Patch)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
info, err := os.Stat(hostPath)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if err := writeFileAtomic(hostPath, []byte(updated), info.Mode().Perm(), nil); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.NoContent(http.StatusNoContent)
}
// Commit godoc
// @Summary Commit a filesystem snapshot
// @Description Create a new version snapshot for the user container
// @Tags fs
// @Success 200 {object} CommitResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /fs/commit [post]
func (h *FSHandler) Commit(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
if h.manager == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "manager not configured")
}
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
if err := h.ensureUserContainer(ctx, userID); err != nil {
return err
}
info, err := h.manager.CreateVersion(ctx, userID)
if err != nil {
if errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusNotFound, "container not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, CommitResponse{
ID: info.ID,
Version: info.Version,
SnapshotID: info.SnapshotID,
CreatedAt: info.CreatedAt,
})
}
// Diff godoc
// @Summary Diff against a version snapshot
// @Description Produce a unified diff between a version snapshot and current data
// @Tags fs
// @Param path query string false "Path under data mount"
// @Param version query int true "Version number"
// @Success 200 {object} DiffResponse
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /fs/diff [get]
func (h *FSHandler) Diff(c echo.Context) error {
userID, err := h.requireUserID(c)
if err != nil {
return err
}
if h.manager == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "manager not configured")
}
versionStr := c.QueryParam("version")
if versionStr == "" {
return echo.NewHTTPError(http.StatusBadRequest, "version is required")
}
version, err := strconv.Atoi(versionStr)
if err != nil || version <= 0 {
return echo.NewHTTPError(http.StatusBadRequest, "invalid version")
}
containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
mount, err := h.mountUser(ctx, userID)
if err != nil {
return err
}
defer mount.Unmount()
versionSnapshotID, err := h.manager.VersionSnapshotID(ctx, userID, version)
if err != nil {
if errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusNotFound, "version not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
versionDir, versionCleanup, err := ctr.MountSnapshot(ctx, h.service, mount.Info.Snapshotter, versionSnapshotID)
if err != nil {
if errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusNotFound, "snapshot not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
defer versionCleanup()
currentHostPath, err := resolveHostPath(mount.Dir, containerPath)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
versionHostPath, err := resolveHostPath(versionDir, containerPath)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
currentContent, err := readFileOrEmpty(currentHostPath)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
versionContent, err := readFileOrEmpty(versionHostPath)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
diffText, err := unifiedDiff(containerPath, versionContent, currentContent)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, DiffResponse{
Path: containerPath,
Version: version,
Diff: diffText,
})
}
func (h *FSHandler) dataMount() string {
if h.mcpConfig.DataMount == "" {
return config.DefaultDataMount
}
return h.mcpConfig.DataMount
}
func (h *FSHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateUserID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return userID, nil
}
func (h *FSHandler) mountUser(ctx context.Context, userID string) (*ctr.MountedSnapshot, error) {
containerID := mcp.ContainerPrefix + userID
mount, err := ctr.MountContainerSnapshot(ctx, h.service, containerID)
if err != nil {
if errdefs.IsNotFound(err) {
return nil, echo.NewHTTPError(http.StatusNotFound, "container not found")
}
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if label, ok := mount.Info.Labels[mcp.UserLabelKey]; !ok || label != userID {
_ = mount.Unmount()
return nil, echo.NewHTTPError(http.StatusForbidden, "user mismatch")
}
return mount, nil
}
func (h *FSHandler) ensureUserContainer(ctx context.Context, userID string) error {
containerID := mcp.ContainerPrefix + userID
container, err := h.service.GetContainer(ctx, containerID)
if err != nil {
if errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusNotFound, "container not found")
}
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
info, err := container.Info(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if label, ok := info.Labels[mcp.UserLabelKey]; !ok || label != userID {
return echo.NewHTTPError(http.StatusForbidden, "user mismatch")
}
return nil
}
func resolveContainerPath(dataMount, requestPath string) (string, error) {
mountPath := path.Clean(dataMount)
if mountPath == "." || !strings.HasPrefix(mountPath, "/") {
return "", fmt.Errorf("data mount must be absolute")
}
if requestPath == "" {
return mountPath, nil
}
reqClean := path.Clean(requestPath)
if path.IsAbs(reqClean) {
if !pathWithin(reqClean, mountPath) {
return "", fmt.Errorf("path outside data mount")
}
return reqClean, nil
}
return path.Join(mountPath, reqClean), nil
}
func pathWithin(target, base string) bool {
if base == "/" {
return strings.HasPrefix(target, "/")
}
if target == base {
return true
}
if strings.HasPrefix(target, base) {
return len(target) > len(base) && target[len(base)] == '/'
}
return false
}
func resolveHostPath(mountDir, containerPath string) (string, error) {
rel := strings.TrimPrefix(containerPath, "/")
return securejoin.SecureJoin(mountDir, rel)
}
func containerPathForHost(mountDir, hostPath string) (string, error) {
rel, err := filepath.Rel(mountDir, hostPath)
if err != nil {
return "", err
}
if strings.HasPrefix(rel, "..") {
return "", fmt.Errorf("path escapes mount")
}
return "/" + filepath.ToSlash(rel), nil
}
func decodeContent(content, encoding string) ([]byte, error) {
switch strings.ToLower(encoding) {
case "", "plain":
return []byte(content), nil
case "base64":
return base64.StdEncoding.DecodeString(content)
default:
return nil, fmt.Errorf("unsupported encoding")
}
}
func writeFileAtomic(targetPath string, data []byte, mode os.FileMode, modTime *time.Time) error {
dir := filepath.Dir(targetPath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
tmp, err := os.CreateTemp(dir, ".tmp-*")
if err != nil {
return err
}
tmpName := tmp.Name()
defer os.Remove(tmpName)
if _, err := io.Copy(tmp, bytes.NewReader(data)); err != nil {
_ = tmp.Close()
return err
}
if err := tmp.Sync(); err != nil {
_ = tmp.Close()
return err
}
if err := tmp.Chmod(mode); err != nil {
_ = tmp.Close()
return err
}
if err := tmp.Close(); err != nil {
return err
}
if modTime != nil {
if err := os.Chtimes(tmpName, *modTime, *modTime); err != nil {
return err
}
}
if err := os.Rename(tmpName, targetPath); err != nil {
return err
}
if modTime != nil {
_ = os.Chtimes(targetPath, *modTime, *modTime)
}
return nil
}
func applyUnifiedPatch(original, patch string) (string, error) {
lines := strings.Split(original, "\n")
out := make([]string, 0, len(lines))
index := 0
patchLines := strings.Split(patch, "\n")
hunksApplied := 0
for i := 0; i < len(patchLines); i++ {
line := patchLines[i]
if !strings.HasPrefix(line, "@@") {
continue
}
origStart, err := parseUnifiedHunkHeader(line)
if err != nil {
return "", err
}
origStart--
if origStart < 0 {
origStart = 0
}
if origStart > len(lines) {
return "", fmt.Errorf("patch out of range")
}
out = append(out, lines[index:origStart]...)
index = origStart
hunksApplied++
for i+1 < len(patchLines) {
next := patchLines[i+1]
if strings.HasPrefix(next, "@@") {
break
}
i++
if next == "" {
if i == len(patchLines)-1 {
break
}
return "", fmt.Errorf("invalid patch line")
}
if next[0] == '\\' {
continue
}
if len(next) < 1 {
return "", fmt.Errorf("invalid patch line")
}
op := next[0]
text := next[1:]
switch op {
case ' ':
if index >= len(lines) || lines[index] != text {
return "", fmt.Errorf("patch context mismatch")
}
out = append(out, text)
index++
case '-':
if index >= len(lines) || lines[index] != text {
return "", fmt.Errorf("patch delete mismatch")
}
index++
case '+':
out = append(out, text)
default:
return "", fmt.Errorf("invalid patch operation")
}
}
}
if hunksApplied == 0 {
return "", fmt.Errorf("patch contains no hunks")
}
out = append(out, lines[index:]...)
return strings.Join(out, "\n"), nil
}
func parseUnifiedHunkHeader(header string) (int, error) {
trimmed := strings.TrimPrefix(header, "@@")
trimmed = strings.TrimSpace(trimmed)
if !strings.HasPrefix(trimmed, "-") {
return 0, fmt.Errorf("invalid hunk header")
}
parts := strings.SplitN(trimmed, " ", 2)
if len(parts) < 2 {
return 0, fmt.Errorf("invalid hunk header")
}
origPart := strings.TrimPrefix(parts[0], "-")
origFields := strings.SplitN(origPart, ",", 2)
origStart, err := strconv.Atoi(origFields[0])
if err != nil {
return 0, fmt.Errorf("invalid hunk header")
}
return origStart, nil
}
func readFileOrEmpty(path string) (string, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
return string(data), nil
}
func unifiedDiff(containerPath, oldContent, newContent string) (string, error) {
diffText, err := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
A: strings.Split(oldContent, "\n"),
B: strings.Split(newContent, "\n"),
FromFile: "a" + containerPath,
ToFile: "b" + containerPath,
Context: 3,
})
if err != nil {
return "", err
}
return diffText, nil
}
+408
View File
@@ -0,0 +1,408 @@
package mcp
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"time"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
)
type EchoInput struct {
Text string `json:"text" jsonschema:"text to echo"`
}
type EchoOutput struct {
Text string `json:"text" jsonschema:"echoed text"`
}
type FSReadInput struct {
Path string `json:"path" jsonschema:"relative file path"`
}
type FSReadOutput struct {
Content string `json:"content" jsonschema:"file content"`
}
type FSWriteInput struct {
Path string `json:"path" jsonschema:"relative file path"`
Content string `json:"content" jsonschema:"file content"`
}
type FSWriteOutput struct {
OK bool `json:"ok" jsonschema:"write result"`
}
type FSListInput struct {
Path string `json:"path" jsonschema:"relative directory path"`
Recursive bool `json:"recursive" jsonschema:"recursive listing"`
}
type FSFileEntry struct {
Path string `json:"path" jsonschema:"relative entry path"`
IsDir bool `json:"is_dir" jsonschema:"is directory"`
Size int64 `json:"size" jsonschema:"entry size"`
Mode uint32 `json:"mode" jsonschema:"file mode"`
ModTime time.Time `json:"mod_time" jsonschema:"modification time"`
}
type FSListOutput struct {
Path string `json:"path" jsonschema:"listed path"`
Entries []FSFileEntry `json:"entries" jsonschema:"entries"`
}
type FSStatInput struct {
Path string `json:"path" jsonschema:"relative path"`
}
type FSStatOutput struct {
Entry FSFileEntry `json:"entry" jsonschema:"entry"`
}
type FSDeleteInput struct {
Path string `json:"path" jsonschema:"relative path"`
}
type FSDeleteOutput struct {
OK bool `json:"ok" jsonschema:"delete result"`
}
type FSApplyPatchInput struct {
Path string `json:"path" jsonschema:"relative file path"`
Patch string `json:"patch" jsonschema:"unified diff patch"`
}
type FSApplyPatchOutput struct {
OK bool `json:"ok" jsonschema:"apply result"`
}
func RegisterTools(server *sdkmcp.Server) {
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "echo", Description: "echo input text"}, echoTool)
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.read", Description: "read file content"}, fsReadTool)
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.write", Description: "write file content"}, fsWriteTool)
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.list", Description: "list directory entries"}, fsListTool)
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.stat", Description: "stat file or directory"}, fsStatTool)
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.delete", Description: "delete file or directory"}, fsDeleteTool)
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.apply_patch", Description: "apply unified diff patch"}, fsApplyPatchTool)
}
func echoTool(ctx context.Context, req *sdkmcp.CallToolRequest, input EchoInput) (
*sdkmcp.CallToolResult,
EchoOutput,
error,
) {
return nil, EchoOutput{Text: input.Text}, nil
}
func fsReadTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSReadInput) (
*sdkmcp.CallToolResult,
FSReadOutput,
error,
) {
root := dataRoot()
target, err := resolvePath(root, input.Path)
if err != nil {
return nil, FSReadOutput{}, err
}
data, err := os.ReadFile(target)
if err != nil {
return nil, FSReadOutput{}, err
}
return nil, FSReadOutput{Content: string(data)}, nil
}
func fsWriteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSWriteInput) (
*sdkmcp.CallToolResult,
FSWriteOutput,
error,
) {
root := dataRoot()
target, err := resolvePath(root, input.Path)
if err != nil {
return nil, FSWriteOutput{}, err
}
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return nil, FSWriteOutput{}, err
}
if err := os.WriteFile(target, []byte(input.Content), 0o644); err != nil {
return nil, FSWriteOutput{}, err
}
return nil, FSWriteOutput{OK: true}, nil
}
func fsListTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSListInput) (
*sdkmcp.CallToolResult,
FSListOutput,
error,
) {
root := dataRoot()
target, err := resolvePathAllowRoot(root, input.Path)
if err != nil {
return nil, FSListOutput{}, err
}
info, err := os.Stat(target)
if err != nil {
return nil, FSListOutput{}, err
}
if !info.IsDir() {
return nil, FSListOutput{}, fmt.Errorf("path is not a directory")
}
entries := []FSFileEntry{}
if input.Recursive {
err = filepath.WalkDir(target, func(p string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if p == target {
return nil
}
entryInfo, err := d.Info()
if err != nil {
return err
}
entry, err := entryForPath(root, p, entryInfo)
if err != nil {
return err
}
entries = append(entries, entry)
return nil
})
} else {
dirEntries, err := os.ReadDir(target)
if err != nil {
return nil, FSListOutput{}, err
}
for _, entry := range dirEntries {
entryInfo, err := entry.Info()
if err != nil {
return nil, FSListOutput{}, err
}
fullPath := filepath.Join(target, entry.Name())
fileEntry, err := entryForPath(root, fullPath, entryInfo)
if err != nil {
return nil, FSListOutput{}, err
}
entries = append(entries, fileEntry)
}
}
if err != nil {
return nil, FSListOutput{}, err
}
listedPath := strings.TrimSpace(input.Path)
if listedPath == "" {
listedPath = "."
}
return nil, FSListOutput{Path: listedPath, Entries: entries}, nil
}
func fsStatTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSStatInput) (
*sdkmcp.CallToolResult,
FSStatOutput,
error,
) {
root := dataRoot()
target, err := resolvePathAllowRoot(root, input.Path)
if err != nil {
return nil, FSStatOutput{}, err
}
info, err := os.Stat(target)
if err != nil {
return nil, FSStatOutput{}, err
}
entry, err := entryForPath(root, target, info)
if err != nil {
return nil, FSStatOutput{}, err
}
return nil, FSStatOutput{Entry: entry}, nil
}
func fsDeleteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSDeleteInput) (
*sdkmcp.CallToolResult,
FSDeleteOutput,
error,
) {
root := dataRoot()
target, err := resolvePath(root, input.Path)
if err != nil {
return nil, FSDeleteOutput{}, err
}
if err := os.RemoveAll(target); err != nil {
return nil, FSDeleteOutput{}, err
}
return nil, FSDeleteOutput{OK: true}, nil
}
func fsApplyPatchTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSApplyPatchInput) (
*sdkmcp.CallToolResult,
FSApplyPatchOutput,
error,
) {
root := dataRoot()
target, err := resolvePath(root, input.Path)
if err != nil {
return nil, FSApplyPatchOutput{}, err
}
orig, err := os.ReadFile(target)
if err != nil {
return nil, FSApplyPatchOutput{}, err
}
updated, err := applyUnifiedPatch(string(orig), input.Patch)
if err != nil {
return nil, FSApplyPatchOutput{}, err
}
info, err := os.Stat(target)
if err != nil {
return nil, FSApplyPatchOutput{}, err
}
if err := os.WriteFile(target, []byte(updated), info.Mode().Perm()); err != nil {
return nil, FSApplyPatchOutput{}, err
}
return nil, FSApplyPatchOutput{OK: true}, nil
}
func dataRoot() string {
root := strings.TrimSpace(os.Getenv("MCP_DATA_DIR"))
if root == "" {
root = "/data"
}
return root
}
func resolvePathAllowRoot(root, requestPath string) (string, error) {
if strings.TrimSpace(requestPath) == "" {
return root, nil
}
return resolvePath(root, requestPath)
}
func resolvePath(root, requestPath string) (string, error) {
clean := filepath.Clean(requestPath)
if clean == "." || clean == "" {
return "", os.ErrInvalid
}
if filepath.IsAbs(clean) || strings.HasPrefix(clean, "..") {
return "", os.ErrInvalid
}
return filepath.Join(root, clean), nil
}
func entryForPath(root, target string, info os.FileInfo) (FSFileEntry, error) {
rel, err := filepath.Rel(root, target)
if err != nil {
return FSFileEntry{}, err
}
if strings.HasPrefix(rel, "..") {
return FSFileEntry{}, os.ErrInvalid
}
if rel == "." {
rel = ""
}
return FSFileEntry{
Path: filepath.ToSlash(rel),
IsDir: info.IsDir(),
Size: info.Size(),
Mode: uint32(info.Mode().Perm()),
ModTime: info.ModTime(),
}, nil
}
func applyUnifiedPatch(original, patch string) (string, error) {
lines := strings.Split(original, "\n")
out := make([]string, 0, len(lines))
index := 0
patchLines := strings.Split(patch, "\n")
hunksApplied := 0
for i := 0; i < len(patchLines); i++ {
line := patchLines[i]
if !strings.HasPrefix(line, "@@") {
continue
}
origStart, err := parseUnifiedHunkHeader(line)
if err != nil {
return "", err
}
origStart--
if origStart < 0 {
origStart = 0
}
if origStart > len(lines) {
return "", fmt.Errorf("patch out of range")
}
out = append(out, lines[index:origStart]...)
index = origStart
hunksApplied++
for i+1 < len(patchLines) {
next := patchLines[i+1]
if strings.HasPrefix(next, "@@") {
break
}
i++
if next == "" {
if i == len(patchLines)-1 {
break
}
return "", fmt.Errorf("invalid patch line")
}
if next[0] == '\\' {
continue
}
op := next[0]
text := next[1:]
switch op {
case ' ':
if index >= len(lines) || lines[index] != text {
return "", fmt.Errorf("patch context mismatch")
}
out = append(out, text)
index++
case '-':
if index >= len(lines) || lines[index] != text {
return "", fmt.Errorf("patch delete mismatch")
}
index++
case '+':
out = append(out, text)
default:
return "", fmt.Errorf("invalid patch operation")
}
}
}
if hunksApplied == 0 {
return "", fmt.Errorf("patch contains no hunks")
}
out = append(out, lines[index:]...)
return strings.Join(out, "\n"), nil
}
func parseUnifiedHunkHeader(header string) (int, error) {
trimmed := strings.TrimPrefix(header, "@@")
trimmed = strings.TrimSpace(trimmed)
if !strings.HasPrefix(trimmed, "-") {
return 0, fmt.Errorf("invalid hunk header")
}
parts := strings.SplitN(trimmed, " ", 2)
if len(parts) < 2 {
return 0, fmt.Errorf("invalid hunk header")
}
origPart := strings.TrimPrefix(parts[0], "-")
origFields := strings.SplitN(origPart, ",", 2)
origStart, err := strconv.Atoi(origFields[0])
if err != nil {
return 0, fmt.Errorf("invalid hunk header")
}
return origStart, nil
}
+57
View File
@@ -0,0 +1,57 @@
package models
import (
"context"
"fmt"
"strings"
"github.com/memohai/memoh/internal/db/sqlc"
)
// SelectMemoryModel selects a chat model for memory operations.
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) {
// First try to get the memory-enabled model.
memoryModel, err := modelsService.GetByEnableAs(ctx, EnableAsMemory)
if err == nil {
provider, err := FetchProviderByID(ctx, queries, memoryModel.LlmProviderID)
if err != nil {
return GetResponse{}, sqlc.LlmProvider{}, err
}
return memoryModel, provider, nil
}
// Fallback to chat model.
chatModel, err := modelsService.GetByEnableAs(ctx, EnableAsChat)
if err == nil {
provider, err := FetchProviderByID(ctx, queries, chatModel.LlmProviderID)
if err != nil {
return GetResponse{}, sqlc.LlmProvider{}, err
}
return chatModel, provider, nil
}
// If no enabled models, try to find any chat model.
candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
if err != nil || len(candidates) == 0 {
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
}
selected := candidates[0]
provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
if err != nil {
return GetResponse{}, sqlc.LlmProvider{}, err
}
return selected, provider, nil
}
// FetchProviderByID fetches a provider by ID.
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
if strings.TrimSpace(providerID) == "" {
return sqlc.LlmProvider{}, fmt.Errorf("provider id missing")
}
parsed, err := parseUUID(providerID)
if err != nil {
return sqlc.LlmProvider{}, err
}
return queries.GetLlmProviderByID(ctx, parsed)
}
+7 -4
View File
@@ -15,7 +15,7 @@ type Server struct {
addr string
}
func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, fsHandler *handlers.FSHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler) *Server {
func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, containerdHandler *handlers.ContainerdHandler) *Server {
if addr == "" {
addr = ":8080"
}
@@ -29,6 +29,9 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
if path == "/ping" || path == "/api/swagger.json" || path == "/auth/login" {
return true
}
if strings.HasPrefix(path, "/mcp/") {
return true
}
if strings.HasPrefix(path, "/api/docs") {
return true
}
@@ -47,9 +50,6 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
if embeddingsHandler != nil {
embeddingsHandler.Register(e)
}
if fsHandler != nil {
fsHandler.Register(e)
}
if swaggerHandler != nil {
swaggerHandler.Register(e)
}
@@ -62,6 +62,9 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
if modelsHandler != nil {
modelsHandler.Register(e)
}
if containerdHandler != nil {
containerdHandler.Register(e)
}
return &Server{
echo: e,