mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
refact: go mcp tool in containerd
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
// CreateProvider creates a chat provider instance.
|
||||
func CreateProvider(provider dbsqlc.LlmProvider, timeout time.Duration) (Provider, error) {
|
||||
clientType := strings.ToLower(strings.TrimSpace(provider.ClientType))
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
switch clientType {
|
||||
case ProviderOpenAI, ProviderOpenAICompat:
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return nil, fmt.Errorf("openai api key is required")
|
||||
}
|
||||
return NewOpenAIProvider(provider.ApiKey, provider.BaseUrl, timeout)
|
||||
case ProviderAnthropic:
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return nil, fmt.Errorf("anthropic api key is required")
|
||||
}
|
||||
return NewAnthropicProvider(provider.ApiKey, timeout)
|
||||
case ProviderGoogle:
|
||||
if strings.TrimSpace(provider.ApiKey) == "" {
|
||||
return nil, fmt.Errorf("google api key is required")
|
||||
}
|
||||
return NewGoogleProvider(provider.ApiKey, timeout)
|
||||
case ProviderOllama:
|
||||
return NewOllamaProvider(provider.BaseUrl, timeout)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider type: %s", clientType)
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tasksv1 "github.com/containerd/containerd/api/services/tasks/v1"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/containerd/containerd/v2/defaults"
|
||||
"github.com/containerd/containerd/v2/pkg/cio"
|
||||
"github.com/containerd/containerd/v2/pkg/namespaces"
|
||||
"github.com/containerd/errdefs"
|
||||
"github.com/containerd/containerd/v2/pkg/oci"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
)
|
||||
@@ -179,13 +181,21 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine
|
||||
}
|
||||
|
||||
ctx = s.withNamespace(ctx)
|
||||
pullOpts := &PullImageOptions{
|
||||
Unpack: true,
|
||||
Snapshotter: req.Snapshotter,
|
||||
}
|
||||
image, err := s.PullImage(ctx, req.ImageRef, pullOpts)
|
||||
image, err := s.getImageWithFallback(ctx, req.ImageRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
pullOpts := &PullImageOptions{
|
||||
Unpack: true,
|
||||
Snapshotter: req.Snapshotter,
|
||||
}
|
||||
image, err = s.PullImage(ctx, req.ImageRef, pullOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if req.Snapshotter != "" {
|
||||
if err := image.Unpack(ctx, req.Snapshotter); err != nil && !errdefs.IsAlreadyExists(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
snapshotID := req.SnapshotID
|
||||
@@ -224,6 +234,36 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine
|
||||
return s.client.NewContainer(ctx, req.ID, containerOpts...)
|
||||
}
|
||||
|
||||
func (s *DefaultService) getImageWithFallback(ctx context.Context, ref string) (containerd.Image, error) {
|
||||
image, err := s.GetImage(ctx, ref)
|
||||
if err == nil {
|
||||
return image, nil
|
||||
}
|
||||
if strings.HasPrefix(ref, "docker.io/library/") {
|
||||
alt := strings.TrimPrefix(ref, "docker.io/library/")
|
||||
image, altErr := s.GetImage(ctx, alt)
|
||||
if altErr == nil {
|
||||
return image, nil
|
||||
}
|
||||
}
|
||||
images, listErr := s.ListImages(ctx)
|
||||
if listErr == nil {
|
||||
for _, img := range images {
|
||||
name := img.Name()
|
||||
if name == ref || strings.HasSuffix(ref, "/"+name) || strings.HasSuffix(name, "/"+ref) {
|
||||
return img, nil
|
||||
}
|
||||
if strings.HasPrefix(ref, "docker.io/library/") {
|
||||
alt := strings.TrimPrefix(ref, "docker.io/library/")
|
||||
if name == alt || strings.HasSuffix(name, "/"+alt) {
|
||||
return img, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *DefaultService) GetContainer(ctx context.Context, id string) (containerd.Container, error) {
|
||||
if id == "" {
|
||||
return nil, ErrInvalidArgument
|
||||
@@ -580,3 +620,4 @@ func (s *DefaultService) SnapshotMounts(ctx context.Context, snapshotter, key st
|
||||
func (s *DefaultService) withNamespace(ctx context.Context) context.Context {
|
||||
return namespaces.WithNamespace(ctx, s.namespace)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
package embeddings
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
|
||||
// ResolverTextEmbedder adapts Resolver to the Embedder interface for text embeddings.
|
||||
type ResolverTextEmbedder struct {
|
||||
Resolver *Resolver
|
||||
ModelID string
|
||||
Dims int
|
||||
}
|
||||
|
||||
func (e *ResolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) {
|
||||
result, err := e.Resolver.Embed(ctx, Request{
|
||||
Type: TypeText,
|
||||
Model: e.ModelID,
|
||||
Input: Input{Text: input},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Embedding, nil
|
||||
}
|
||||
|
||||
func (e *ResolverTextEmbedder) Dimensions() int {
|
||||
return e.Dims
|
||||
}
|
||||
|
||||
// CollectEmbeddingVectors gathers embedding model dimensions and defaults.
|
||||
func CollectEmbeddingVectors(ctx context.Context, service *models.Service) (map[string]int, models.GetResponse, models.GetResponse, bool, error) {
|
||||
candidates, err := service.ListByType(ctx, models.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return nil, models.GetResponse{}, models.GetResponse{}, false, err
|
||||
}
|
||||
vectors := map[string]int{}
|
||||
var textModel models.GetResponse
|
||||
var multimodalModel models.GetResponse
|
||||
for _, model := range candidates {
|
||||
if model.Dimensions > 0 && model.ModelID != "" {
|
||||
vectors[model.ModelID] = model.Dimensions
|
||||
}
|
||||
if model.IsMultimodal {
|
||||
if multimodalModel.ModelID == "" {
|
||||
multimodalModel = model
|
||||
}
|
||||
continue
|
||||
}
|
||||
if textModel.ModelID == "" {
|
||||
textModel = model
|
||||
}
|
||||
}
|
||||
|
||||
hasTextModel := textModel.ModelID != ""
|
||||
hasMultimodalModel := multimodalModel.ModelID != ""
|
||||
hasAnyModel := hasTextModel || hasMultimodalModel
|
||||
|
||||
return vectors, textModel, multimodalModel, hasAnyModel, nil
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containerd/errdefs"
|
||||
"github.com/containerd/containerd/v2/pkg/namespaces"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
)
|
||||
|
||||
type ContainerdHandler struct {
|
||||
service ctr.Service
|
||||
cfg config.MCPConfig
|
||||
namespace string
|
||||
}
|
||||
|
||||
type CreateContainerRequest struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Snapshotter string `json:"snapshotter,omitempty"`
|
||||
}
|
||||
|
||||
type CreateContainerResponse struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
Image string `json:"image"`
|
||||
Snapshotter string `json:"snapshotter"`
|
||||
Started bool `json:"started"`
|
||||
}
|
||||
|
||||
type CreateSnapshotRequest struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
SnapshotName string `json:"snapshot_name"`
|
||||
}
|
||||
|
||||
type CreateSnapshotResponse struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
SnapshotName string `json:"snapshot_name"`
|
||||
Snapshotter string `json:"snapshotter"`
|
||||
}
|
||||
|
||||
func NewContainerdHandler(service ctr.Service, cfg config.MCPConfig, namespace string) *ContainerdHandler {
|
||||
return &ContainerdHandler{
|
||||
service: service,
|
||||
cfg: cfg,
|
||||
namespace: namespace,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) Register(e *echo.Echo) {
|
||||
group := e.Group("/mcp")
|
||||
group.POST("/containers", h.CreateContainer)
|
||||
group.DELETE("/containers/:id", h.DeleteContainer)
|
||||
group.POST("/snapshots", h.CreateSnapshot)
|
||||
}
|
||||
|
||||
// CreateContainer godoc
|
||||
// @Summary Create and start MCP container
|
||||
// @Tags containerd
|
||||
// @Param payload body CreateContainerRequest true "Create container payload"
|
||||
// @Success 200 {object} CreateContainerResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /mcp/containers [post]
|
||||
func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
var req CreateContainerRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if strings.TrimSpace(req.ContainerID) == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "container_id is required")
|
||||
}
|
||||
|
||||
image := strings.TrimSpace(req.Image)
|
||||
if image == "" {
|
||||
image = h.cfg.BusyboxImage
|
||||
}
|
||||
if image == "" {
|
||||
image = config.DefaultBusyboxImg
|
||||
}
|
||||
snapshotter := strings.TrimSpace(req.Snapshotter)
|
||||
if snapshotter == "" {
|
||||
snapshotter = h.cfg.Snapshotter
|
||||
}
|
||||
if snapshotter == "" {
|
||||
snapshotter = "overlayfs"
|
||||
}
|
||||
|
||||
_, err := h.service.CreateContainer(c.Request().Context(), ctr.CreateContainerRequest{
|
||||
ID: req.ContainerID,
|
||||
ImageRef: image,
|
||||
Snapshotter: snapshotter,
|
||||
})
|
||||
if err != nil && !errdefs.IsAlreadyExists(err) {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "snapshotter="+snapshotter+" image="+image+" err="+err.Error())
|
||||
}
|
||||
|
||||
started := false
|
||||
if _, err := h.service.StartTask(c.Request().Context(), req.ContainerID, &ctr.StartTaskOptions{
|
||||
UseStdio: false,
|
||||
}); err == nil {
|
||||
started = true
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, CreateContainerResponse{
|
||||
ContainerID: req.ContainerID,
|
||||
Image: image,
|
||||
Snapshotter: snapshotter,
|
||||
Started: started,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteContainer godoc
|
||||
// @Summary Delete MCP container
|
||||
// @Tags containerd
|
||||
// @Param id path string true "Container ID"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /mcp/containers/{id} [delete]
|
||||
func (h *ContainerdHandler) DeleteContainer(c echo.Context) error {
|
||||
containerID := strings.TrimSpace(c.Param("id"))
|
||||
if containerID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "container id is required")
|
||||
}
|
||||
_ = h.service.DeleteTask(c.Request().Context(), containerID, &ctr.DeleteTaskOptions{Force: true})
|
||||
if err := h.service.DeleteContainer(c.Request().Context(), containerID, &ctr.DeleteContainerOptions{
|
||||
CleanupSnapshot: true,
|
||||
}); err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "container not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// CreateSnapshot godoc
|
||||
// @Summary Create container snapshot
|
||||
// @Tags containerd
|
||||
// @Param payload body CreateSnapshotRequest true "Create snapshot payload"
|
||||
// @Success 200 {object} CreateSnapshotResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /mcp/snapshots [post]
|
||||
func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error {
|
||||
var req CreateSnapshotRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if strings.TrimSpace(req.ContainerID) == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "container_id is required")
|
||||
}
|
||||
container, err := h.service.GetContainer(c.Request().Context(), req.ContainerID)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "container not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
ctx := c.Request().Context()
|
||||
if strings.TrimSpace(h.namespace) != "" {
|
||||
ctx = namespaces.WithNamespace(ctx, h.namespace)
|
||||
}
|
||||
info, err := container.Info(ctx)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
snapshotName := strings.TrimSpace(req.SnapshotName)
|
||||
if snapshotName == "" {
|
||||
snapshotName = req.ContainerID + "-" + time.Now().Format("20060102150405")
|
||||
}
|
||||
if err := h.service.CommitSnapshot(c.Request().Context(), info.Snapshotter, snapshotName, info.SnapshotKey); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, CreateSnapshotResponse{
|
||||
ContainerID: req.ContainerID,
|
||||
SnapshotName: snapshotName,
|
||||
Snapshotter: info.Snapshotter,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package handlers
|
||||
|
||||
type ErrorResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
@@ -1,803 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containerd/containerd/v2/pkg/namespaces"
|
||||
"github.com/containerd/errdefs"
|
||||
securejoin "github.com/cyphar/filepath-securejoin"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pmezard/go-difflib/difflib"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
)
|
||||
|
||||
type FSHandler struct {
|
||||
service ctr.Service
|
||||
manager *mcp.Manager
|
||||
mcpConfig config.MCPConfig
|
||||
namespace string
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ReadResponse struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
Encoding string `json:"encoding"`
|
||||
Size int64 `json:"size"`
|
||||
Mode uint32 `json:"mode"`
|
||||
ModTime time.Time `json:"mod_time"`
|
||||
}
|
||||
|
||||
type FileEntry struct {
|
||||
Path string `json:"path"`
|
||||
IsDir bool `json:"is_dir"`
|
||||
Size int64 `json:"size"`
|
||||
Mode uint32 `json:"mode"`
|
||||
ModTime time.Time `json:"mod_time"`
|
||||
}
|
||||
|
||||
type ListResponse struct {
|
||||
Path string `json:"path"`
|
||||
Entries []FileEntry `json:"entries"`
|
||||
}
|
||||
|
||||
type WriteAtomicRequest struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
Encoding string `json:"encoding"`
|
||||
Mode *uint32 `json:"mode,omitempty"`
|
||||
ModTime *time.Time `json:"mtime,omitempty"`
|
||||
}
|
||||
|
||||
type ApplyPatchRequest struct {
|
||||
Path string `json:"path"`
|
||||
Patch string `json:"patch"`
|
||||
}
|
||||
|
||||
type CommitResponse struct {
|
||||
ID string `json:"id"`
|
||||
Version int `json:"version"`
|
||||
SnapshotID string `json:"snapshot_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type DiffResponse struct {
|
||||
Path string `json:"path"`
|
||||
Version int `json:"version"`
|
||||
Diff string `json:"diff"`
|
||||
}
|
||||
|
||||
func NewFSHandler(service ctr.Service, manager *mcp.Manager, mcpConfig config.MCPConfig, namespace string) *FSHandler {
|
||||
if namespace == "" {
|
||||
namespace = config.DefaultNamespace
|
||||
}
|
||||
return &FSHandler{
|
||||
service: service,
|
||||
manager: manager,
|
||||
mcpConfig: mcpConfig,
|
||||
namespace: namespace,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FSHandler) Register(e *echo.Echo) {
|
||||
group := e.Group("/fs")
|
||||
group.GET("/read", h.Read)
|
||||
group.GET("/list", h.List)
|
||||
group.PUT("/write_atomic", h.WriteAtomic)
|
||||
group.POST("/apply_patch", h.ApplyPatch)
|
||||
group.POST("/commit", h.Commit)
|
||||
group.GET("/diff", h.Diff)
|
||||
}
|
||||
|
||||
// Read godoc
|
||||
// @Summary Read file content
|
||||
// @Description Read a file under the user data mount
|
||||
// @Tags fs
|
||||
// @Param path query string false "Path under data mount"
|
||||
// @Success 200 {object} ReadResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /fs/read [get]
|
||||
func (h *FSHandler) Read(c echo.Context) error {
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
|
||||
mount, err := h.mountUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mount.Unmount()
|
||||
|
||||
containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path"))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
hostPath, err := resolveHostPath(mount.Dir, containerPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
info, err := os.Stat(hostPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "file not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if info.IsDir() {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "path is a directory")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(hostPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, ReadResponse{
|
||||
Path: containerPath,
|
||||
Content: base64.StdEncoding.EncodeToString(data),
|
||||
Encoding: "base64",
|
||||
Size: info.Size(),
|
||||
Mode: uint32(info.Mode().Perm()),
|
||||
ModTime: info.ModTime(),
|
||||
})
|
||||
}
|
||||
|
||||
// List godoc
|
||||
// @Summary List directory contents
|
||||
// @Description List files under the user data mount
|
||||
// @Tags fs
|
||||
// @Param path query string false "Path under data mount"
|
||||
// @Param recursive query bool false "Recursive listing"
|
||||
// @Success 200 {object} ListResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /fs/list [get]
|
||||
func (h *FSHandler) List(c echo.Context) error {
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
|
||||
mount, err := h.mountUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mount.Unmount()
|
||||
|
||||
containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path"))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
hostPath, err := resolveHostPath(mount.Dir, containerPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
info, err := os.Stat(hostPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "path not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "path is not a directory")
|
||||
}
|
||||
|
||||
recursive := strings.EqualFold(c.QueryParam("recursive"), "true")
|
||||
entries := []FileEntry{}
|
||||
if recursive {
|
||||
err = filepath.WalkDir(hostPath, func(p string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if p == hostPath {
|
||||
return nil
|
||||
}
|
||||
entryInfo, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
containerEntry, err := containerPathForHost(mount.Dir, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entries = append(entries, FileEntry{
|
||||
Path: containerEntry,
|
||||
IsDir: d.IsDir(),
|
||||
Size: entryInfo.Size(),
|
||||
Mode: uint32(entryInfo.Mode().Perm()),
|
||||
ModTime: entryInfo.ModTime(),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
dirEntries, err := os.ReadDir(hostPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
for _, entry := range dirEntries {
|
||||
entryInfo, err := entry.Info()
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
entryPath := filepath.Join(hostPath, entry.Name())
|
||||
containerEntry, err := containerPathForHost(mount.Dir, entryPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
entries = append(entries, FileEntry{
|
||||
Path: containerEntry,
|
||||
IsDir: entry.IsDir(),
|
||||
Size: entryInfo.Size(),
|
||||
Mode: uint32(entryInfo.Mode().Perm()),
|
||||
ModTime: entryInfo.ModTime(),
|
||||
})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, ListResponse{
|
||||
Path: containerPath,
|
||||
Entries: entries,
|
||||
})
|
||||
}
|
||||
|
||||
// WriteAtomic godoc
|
||||
// @Summary Write file atomically
|
||||
// @Description Atomically replace a file under the user data mount
|
||||
// @Tags fs
|
||||
// @Param payload body WriteAtomicRequest true "Write payload"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /fs/write_atomic [put]
|
||||
func (h *FSHandler) WriteAtomic(c echo.Context) error {
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req WriteAtomicRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if req.Path == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "path is required")
|
||||
}
|
||||
|
||||
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
|
||||
mount, err := h.mountUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mount.Unmount()
|
||||
|
||||
containerPath, err := resolveContainerPath(h.dataMount(), req.Path)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
hostPath, err := resolveHostPath(mount.Dir, containerPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
data, err := decodeContent(req.Content, req.Encoding)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
mode := os.FileMode(0o644)
|
||||
if req.Mode != nil {
|
||||
mode = os.FileMode(*req.Mode)
|
||||
}
|
||||
|
||||
if err := writeFileAtomic(hostPath, data, mode, req.ModTime); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "path not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ApplyPatch godoc
|
||||
// @Summary Apply unified diff patch
|
||||
// @Description Apply a unified diff patch to a file under the user data mount
|
||||
// @Tags fs
|
||||
// @Param payload body ApplyPatchRequest true "Patch payload"
|
||||
// @Success 204
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /fs/apply_patch [post]
|
||||
func (h *FSHandler) ApplyPatch(c echo.Context) error {
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req ApplyPatchRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if req.Path == "" || req.Patch == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "path and patch are required")
|
||||
}
|
||||
|
||||
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
|
||||
mount, err := h.mountUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mount.Unmount()
|
||||
|
||||
containerPath, err := resolveContainerPath(h.dataMount(), req.Path)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
hostPath, err := resolveHostPath(mount.Dir, containerPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
orig, err := os.ReadFile(hostPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "file not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
updated, err := applyUnifiedPatch(string(orig), req.Patch)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
info, err := os.Stat(hostPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
if err := writeFileAtomic(hostPath, []byte(updated), info.Mode().Perm(), nil); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Commit godoc
|
||||
// @Summary Commit a filesystem snapshot
|
||||
// @Description Create a new version snapshot for the user container
|
||||
// @Tags fs
|
||||
// @Success 200 {object} CommitResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /fs/commit [post]
|
||||
func (h *FSHandler) Commit(c echo.Context) error {
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.manager == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "manager not configured")
|
||||
}
|
||||
|
||||
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
|
||||
if err := h.ensureUserContainer(ctx, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := h.manager.CreateVersion(ctx, userID)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "container not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, CommitResponse{
|
||||
ID: info.ID,
|
||||
Version: info.Version,
|
||||
SnapshotID: info.SnapshotID,
|
||||
CreatedAt: info.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
// Diff godoc
|
||||
// @Summary Diff against a version snapshot
|
||||
// @Description Produce a unified diff between a version snapshot and current data
|
||||
// @Tags fs
|
||||
// @Param path query string false "Path under data mount"
|
||||
// @Param version query int true "Version number"
|
||||
// @Success 200 {object} DiffResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /fs/diff [get]
|
||||
func (h *FSHandler) Diff(c echo.Context) error {
|
||||
userID, err := h.requireUserID(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if h.manager == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "manager not configured")
|
||||
}
|
||||
|
||||
versionStr := c.QueryParam("version")
|
||||
if versionStr == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "version is required")
|
||||
}
|
||||
version, err := strconv.Atoi(versionStr)
|
||||
if err != nil || version <= 0 {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid version")
|
||||
}
|
||||
|
||||
containerPath, err := resolveContainerPath(h.dataMount(), c.QueryParam("path"))
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
ctx := namespaces.WithNamespace(c.Request().Context(), h.namespace)
|
||||
mount, err := h.mountUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mount.Unmount()
|
||||
|
||||
versionSnapshotID, err := h.manager.VersionSnapshotID(ctx, userID, version)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "version not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
versionDir, versionCleanup, err := ctr.MountSnapshot(ctx, h.service, mount.Info.Snapshotter, versionSnapshotID)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "snapshot not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
defer versionCleanup()
|
||||
|
||||
currentHostPath, err := resolveHostPath(mount.Dir, containerPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
versionHostPath, err := resolveHostPath(versionDir, containerPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
currentContent, err := readFileOrEmpty(currentHostPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
versionContent, err := readFileOrEmpty(versionHostPath)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
diffText, err := unifiedDiff(containerPath, versionContent, currentContent)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, DiffResponse{
|
||||
Path: containerPath,
|
||||
Version: version,
|
||||
Diff: diffText,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FSHandler) dataMount() string {
|
||||
if h.mcpConfig.DataMount == "" {
|
||||
return config.DefaultDataMount
|
||||
}
|
||||
return h.mcpConfig.DataMount
|
||||
}
|
||||
|
||||
func (h *FSHandler) requireUserID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateUserID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (h *FSHandler) mountUser(ctx context.Context, userID string) (*ctr.MountedSnapshot, error) {
|
||||
containerID := mcp.ContainerPrefix + userID
|
||||
mount, err := ctr.MountContainerSnapshot(ctx, h.service, containerID)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
return nil, echo.NewHTTPError(http.StatusNotFound, "container not found")
|
||||
}
|
||||
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if label, ok := mount.Info.Labels[mcp.UserLabelKey]; !ok || label != userID {
|
||||
_ = mount.Unmount()
|
||||
return nil, echo.NewHTTPError(http.StatusForbidden, "user mismatch")
|
||||
}
|
||||
return mount, nil
|
||||
}
|
||||
|
||||
func (h *FSHandler) ensureUserContainer(ctx context.Context, userID string) error {
|
||||
containerID := mcp.ContainerPrefix + userID
|
||||
container, err := h.service.GetContainer(ctx, containerID)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "container not found")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
info, err := container.Info(ctx)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
if label, ok := info.Labels[mcp.UserLabelKey]; !ok || label != userID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "user mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveContainerPath(dataMount, requestPath string) (string, error) {
|
||||
mountPath := path.Clean(dataMount)
|
||||
if mountPath == "." || !strings.HasPrefix(mountPath, "/") {
|
||||
return "", fmt.Errorf("data mount must be absolute")
|
||||
}
|
||||
|
||||
if requestPath == "" {
|
||||
return mountPath, nil
|
||||
}
|
||||
|
||||
reqClean := path.Clean(requestPath)
|
||||
if path.IsAbs(reqClean) {
|
||||
if !pathWithin(reqClean, mountPath) {
|
||||
return "", fmt.Errorf("path outside data mount")
|
||||
}
|
||||
return reqClean, nil
|
||||
}
|
||||
|
||||
return path.Join(mountPath, reqClean), nil
|
||||
}
|
||||
|
||||
func pathWithin(target, base string) bool {
|
||||
if base == "/" {
|
||||
return strings.HasPrefix(target, "/")
|
||||
}
|
||||
if target == base {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(target, base) {
|
||||
return len(target) > len(base) && target[len(base)] == '/'
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveHostPath(mountDir, containerPath string) (string, error) {
|
||||
rel := strings.TrimPrefix(containerPath, "/")
|
||||
return securejoin.SecureJoin(mountDir, rel)
|
||||
}
|
||||
|
||||
func containerPathForHost(mountDir, hostPath string) (string, error) {
|
||||
rel, err := filepath.Rel(mountDir, hostPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.HasPrefix(rel, "..") {
|
||||
return "", fmt.Errorf("path escapes mount")
|
||||
}
|
||||
return "/" + filepath.ToSlash(rel), nil
|
||||
}
|
||||
|
||||
func decodeContent(content, encoding string) ([]byte, error) {
|
||||
switch strings.ToLower(encoding) {
|
||||
case "", "plain":
|
||||
return []byte(content), nil
|
||||
case "base64":
|
||||
return base64.StdEncoding.DecodeString(content)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported encoding")
|
||||
}
|
||||
}
|
||||
|
||||
func writeFileAtomic(targetPath string, data []byte, mode os.FileMode, modTime *time.Time) error {
|
||||
dir := filepath.Dir(targetPath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp(dir, ".tmp-*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpName := tmp.Name()
|
||||
defer os.Remove(tmpName)
|
||||
|
||||
if _, err := io.Copy(tmp, bytes.NewReader(data)); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Chmod(mode); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if modTime != nil {
|
||||
if err := os.Chtimes(tmpName, *modTime, *modTime); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := os.Rename(tmpName, targetPath); err != nil {
|
||||
return err
|
||||
}
|
||||
if modTime != nil {
|
||||
_ = os.Chtimes(targetPath, *modTime, *modTime)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyUnifiedPatch(original, patch string) (string, error) {
|
||||
lines := strings.Split(original, "\n")
|
||||
out := make([]string, 0, len(lines))
|
||||
index := 0
|
||||
patchLines := strings.Split(patch, "\n")
|
||||
hunksApplied := 0
|
||||
|
||||
for i := 0; i < len(patchLines); i++ {
|
||||
line := patchLines[i]
|
||||
if !strings.HasPrefix(line, "@@") {
|
||||
continue
|
||||
}
|
||||
|
||||
origStart, err := parseUnifiedHunkHeader(line)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
origStart--
|
||||
if origStart < 0 {
|
||||
origStart = 0
|
||||
}
|
||||
if origStart > len(lines) {
|
||||
return "", fmt.Errorf("patch out of range")
|
||||
}
|
||||
|
||||
out = append(out, lines[index:origStart]...)
|
||||
index = origStart
|
||||
hunksApplied++
|
||||
|
||||
for i+1 < len(patchLines) {
|
||||
next := patchLines[i+1]
|
||||
if strings.HasPrefix(next, "@@") {
|
||||
break
|
||||
}
|
||||
i++
|
||||
|
||||
if next == "" {
|
||||
if i == len(patchLines)-1 {
|
||||
break
|
||||
}
|
||||
return "", fmt.Errorf("invalid patch line")
|
||||
}
|
||||
if next[0] == '\\' {
|
||||
continue
|
||||
}
|
||||
if len(next) < 1 {
|
||||
return "", fmt.Errorf("invalid patch line")
|
||||
}
|
||||
op := next[0]
|
||||
text := next[1:]
|
||||
switch op {
|
||||
case ' ':
|
||||
if index >= len(lines) || lines[index] != text {
|
||||
return "", fmt.Errorf("patch context mismatch")
|
||||
}
|
||||
out = append(out, text)
|
||||
index++
|
||||
case '-':
|
||||
if index >= len(lines) || lines[index] != text {
|
||||
return "", fmt.Errorf("patch delete mismatch")
|
||||
}
|
||||
index++
|
||||
case '+':
|
||||
out = append(out, text)
|
||||
default:
|
||||
return "", fmt.Errorf("invalid patch operation")
|
||||
}
|
||||
}
|
||||
}
|
||||
if hunksApplied == 0 {
|
||||
return "", fmt.Errorf("patch contains no hunks")
|
||||
}
|
||||
|
||||
out = append(out, lines[index:]...)
|
||||
return strings.Join(out, "\n"), nil
|
||||
}
|
||||
|
||||
func parseUnifiedHunkHeader(header string) (int, error) {
|
||||
trimmed := strings.TrimPrefix(header, "@@")
|
||||
trimmed = strings.TrimSpace(trimmed)
|
||||
if !strings.HasPrefix(trimmed, "-") {
|
||||
return 0, fmt.Errorf("invalid hunk header")
|
||||
}
|
||||
parts := strings.SplitN(trimmed, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return 0, fmt.Errorf("invalid hunk header")
|
||||
}
|
||||
|
||||
origPart := strings.TrimPrefix(parts[0], "-")
|
||||
origFields := strings.SplitN(origPart, ",", 2)
|
||||
origStart, err := strconv.Atoi(origFields[0])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid hunk header")
|
||||
}
|
||||
return origStart, nil
|
||||
}
|
||||
|
||||
func readFileOrEmpty(path string) (string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func unifiedDiff(containerPath, oldContent, newContent string) (string, error) {
|
||||
diffText, err := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
|
||||
A: strings.Split(oldContent, "\n"),
|
||||
B: strings.Split(newContent, "\n"),
|
||||
FromFile: "a" + containerPath,
|
||||
ToFile: "b" + containerPath,
|
||||
Context: 3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return diffText, nil
|
||||
}
|
||||
@@ -0,0 +1,408 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
type EchoInput struct {
|
||||
Text string `json:"text" jsonschema:"text to echo"`
|
||||
}
|
||||
|
||||
type EchoOutput struct {
|
||||
Text string `json:"text" jsonschema:"echoed text"`
|
||||
}
|
||||
|
||||
type FSReadInput struct {
|
||||
Path string `json:"path" jsonschema:"relative file path"`
|
||||
}
|
||||
|
||||
type FSReadOutput struct {
|
||||
Content string `json:"content" jsonschema:"file content"`
|
||||
}
|
||||
|
||||
type FSWriteInput struct {
|
||||
Path string `json:"path" jsonschema:"relative file path"`
|
||||
Content string `json:"content" jsonschema:"file content"`
|
||||
}
|
||||
|
||||
type FSWriteOutput struct {
|
||||
OK bool `json:"ok" jsonschema:"write result"`
|
||||
}
|
||||
|
||||
type FSListInput struct {
|
||||
Path string `json:"path" jsonschema:"relative directory path"`
|
||||
Recursive bool `json:"recursive" jsonschema:"recursive listing"`
|
||||
}
|
||||
|
||||
type FSFileEntry struct {
|
||||
Path string `json:"path" jsonschema:"relative entry path"`
|
||||
IsDir bool `json:"is_dir" jsonschema:"is directory"`
|
||||
Size int64 `json:"size" jsonschema:"entry size"`
|
||||
Mode uint32 `json:"mode" jsonschema:"file mode"`
|
||||
ModTime time.Time `json:"mod_time" jsonschema:"modification time"`
|
||||
}
|
||||
|
||||
type FSListOutput struct {
|
||||
Path string `json:"path" jsonschema:"listed path"`
|
||||
Entries []FSFileEntry `json:"entries" jsonschema:"entries"`
|
||||
}
|
||||
|
||||
type FSStatInput struct {
|
||||
Path string `json:"path" jsonschema:"relative path"`
|
||||
}
|
||||
|
||||
type FSStatOutput struct {
|
||||
Entry FSFileEntry `json:"entry" jsonschema:"entry"`
|
||||
}
|
||||
|
||||
type FSDeleteInput struct {
|
||||
Path string `json:"path" jsonschema:"relative path"`
|
||||
}
|
||||
|
||||
type FSDeleteOutput struct {
|
||||
OK bool `json:"ok" jsonschema:"delete result"`
|
||||
}
|
||||
|
||||
type FSApplyPatchInput struct {
|
||||
Path string `json:"path" jsonschema:"relative file path"`
|
||||
Patch string `json:"patch" jsonschema:"unified diff patch"`
|
||||
}
|
||||
|
||||
type FSApplyPatchOutput struct {
|
||||
OK bool `json:"ok" jsonschema:"apply result"`
|
||||
}
|
||||
|
||||
func RegisterTools(server *sdkmcp.Server) {
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "echo", Description: "echo input text"}, echoTool)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.read", Description: "read file content"}, fsReadTool)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.write", Description: "write file content"}, fsWriteTool)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.list", Description: "list directory entries"}, fsListTool)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.stat", Description: "stat file or directory"}, fsStatTool)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.delete", Description: "delete file or directory"}, fsDeleteTool)
|
||||
sdkmcp.AddTool(server, &sdkmcp.Tool{Name: "fs.apply_patch", Description: "apply unified diff patch"}, fsApplyPatchTool)
|
||||
}
|
||||
|
||||
func echoTool(ctx context.Context, req *sdkmcp.CallToolRequest, input EchoInput) (
|
||||
*sdkmcp.CallToolResult,
|
||||
EchoOutput,
|
||||
error,
|
||||
) {
|
||||
return nil, EchoOutput{Text: input.Text}, nil
|
||||
}
|
||||
|
||||
func fsReadTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSReadInput) (
|
||||
*sdkmcp.CallToolResult,
|
||||
FSReadOutput,
|
||||
error,
|
||||
) {
|
||||
root := dataRoot()
|
||||
target, err := resolvePath(root, input.Path)
|
||||
if err != nil {
|
||||
return nil, FSReadOutput{}, err
|
||||
}
|
||||
data, err := os.ReadFile(target)
|
||||
if err != nil {
|
||||
return nil, FSReadOutput{}, err
|
||||
}
|
||||
return nil, FSReadOutput{Content: string(data)}, nil
|
||||
}
|
||||
|
||||
func fsWriteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSWriteInput) (
|
||||
*sdkmcp.CallToolResult,
|
||||
FSWriteOutput,
|
||||
error,
|
||||
) {
|
||||
root := dataRoot()
|
||||
target, err := resolvePath(root, input.Path)
|
||||
if err != nil {
|
||||
return nil, FSWriteOutput{}, err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
return nil, FSWriteOutput{}, err
|
||||
}
|
||||
if err := os.WriteFile(target, []byte(input.Content), 0o644); err != nil {
|
||||
return nil, FSWriteOutput{}, err
|
||||
}
|
||||
return nil, FSWriteOutput{OK: true}, nil
|
||||
}
|
||||
|
||||
func fsListTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSListInput) (
|
||||
*sdkmcp.CallToolResult,
|
||||
FSListOutput,
|
||||
error,
|
||||
) {
|
||||
root := dataRoot()
|
||||
target, err := resolvePathAllowRoot(root, input.Path)
|
||||
if err != nil {
|
||||
return nil, FSListOutput{}, err
|
||||
}
|
||||
info, err := os.Stat(target)
|
||||
if err != nil {
|
||||
return nil, FSListOutput{}, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, FSListOutput{}, fmt.Errorf("path is not a directory")
|
||||
}
|
||||
|
||||
entries := []FSFileEntry{}
|
||||
if input.Recursive {
|
||||
err = filepath.WalkDir(target, func(p string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if p == target {
|
||||
return nil
|
||||
}
|
||||
entryInfo, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entry, err := entryForPath(root, p, entryInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
dirEntries, err := os.ReadDir(target)
|
||||
if err != nil {
|
||||
return nil, FSListOutput{}, err
|
||||
}
|
||||
for _, entry := range dirEntries {
|
||||
entryInfo, err := entry.Info()
|
||||
if err != nil {
|
||||
return nil, FSListOutput{}, err
|
||||
}
|
||||
fullPath := filepath.Join(target, entry.Name())
|
||||
fileEntry, err := entryForPath(root, fullPath, entryInfo)
|
||||
if err != nil {
|
||||
return nil, FSListOutput{}, err
|
||||
}
|
||||
entries = append(entries, fileEntry)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, FSListOutput{}, err
|
||||
}
|
||||
|
||||
listedPath := strings.TrimSpace(input.Path)
|
||||
if listedPath == "" {
|
||||
listedPath = "."
|
||||
}
|
||||
return nil, FSListOutput{Path: listedPath, Entries: entries}, nil
|
||||
}
|
||||
|
||||
func fsStatTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSStatInput) (
|
||||
*sdkmcp.CallToolResult,
|
||||
FSStatOutput,
|
||||
error,
|
||||
) {
|
||||
root := dataRoot()
|
||||
target, err := resolvePathAllowRoot(root, input.Path)
|
||||
if err != nil {
|
||||
return nil, FSStatOutput{}, err
|
||||
}
|
||||
info, err := os.Stat(target)
|
||||
if err != nil {
|
||||
return nil, FSStatOutput{}, err
|
||||
}
|
||||
entry, err := entryForPath(root, target, info)
|
||||
if err != nil {
|
||||
return nil, FSStatOutput{}, err
|
||||
}
|
||||
return nil, FSStatOutput{Entry: entry}, nil
|
||||
}
|
||||
|
||||
func fsDeleteTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSDeleteInput) (
|
||||
*sdkmcp.CallToolResult,
|
||||
FSDeleteOutput,
|
||||
error,
|
||||
) {
|
||||
root := dataRoot()
|
||||
target, err := resolvePath(root, input.Path)
|
||||
if err != nil {
|
||||
return nil, FSDeleteOutput{}, err
|
||||
}
|
||||
if err := os.RemoveAll(target); err != nil {
|
||||
return nil, FSDeleteOutput{}, err
|
||||
}
|
||||
return nil, FSDeleteOutput{OK: true}, nil
|
||||
}
|
||||
|
||||
func fsApplyPatchTool(ctx context.Context, req *sdkmcp.CallToolRequest, input FSApplyPatchInput) (
|
||||
*sdkmcp.CallToolResult,
|
||||
FSApplyPatchOutput,
|
||||
error,
|
||||
) {
|
||||
root := dataRoot()
|
||||
target, err := resolvePath(root, input.Path)
|
||||
if err != nil {
|
||||
return nil, FSApplyPatchOutput{}, err
|
||||
}
|
||||
orig, err := os.ReadFile(target)
|
||||
if err != nil {
|
||||
return nil, FSApplyPatchOutput{}, err
|
||||
}
|
||||
updated, err := applyUnifiedPatch(string(orig), input.Patch)
|
||||
if err != nil {
|
||||
return nil, FSApplyPatchOutput{}, err
|
||||
}
|
||||
info, err := os.Stat(target)
|
||||
if err != nil {
|
||||
return nil, FSApplyPatchOutput{}, err
|
||||
}
|
||||
if err := os.WriteFile(target, []byte(updated), info.Mode().Perm()); err != nil {
|
||||
return nil, FSApplyPatchOutput{}, err
|
||||
}
|
||||
return nil, FSApplyPatchOutput{OK: true}, nil
|
||||
}
|
||||
|
||||
func dataRoot() string {
|
||||
root := strings.TrimSpace(os.Getenv("MCP_DATA_DIR"))
|
||||
if root == "" {
|
||||
root = "/data"
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func resolvePathAllowRoot(root, requestPath string) (string, error) {
|
||||
if strings.TrimSpace(requestPath) == "" {
|
||||
return root, nil
|
||||
}
|
||||
return resolvePath(root, requestPath)
|
||||
}
|
||||
|
||||
func resolvePath(root, requestPath string) (string, error) {
|
||||
clean := filepath.Clean(requestPath)
|
||||
if clean == "." || clean == "" {
|
||||
return "", os.ErrInvalid
|
||||
}
|
||||
if filepath.IsAbs(clean) || strings.HasPrefix(clean, "..") {
|
||||
return "", os.ErrInvalid
|
||||
}
|
||||
return filepath.Join(root, clean), nil
|
||||
}
|
||||
|
||||
func entryForPath(root, target string, info os.FileInfo) (FSFileEntry, error) {
|
||||
rel, err := filepath.Rel(root, target)
|
||||
if err != nil {
|
||||
return FSFileEntry{}, err
|
||||
}
|
||||
if strings.HasPrefix(rel, "..") {
|
||||
return FSFileEntry{}, os.ErrInvalid
|
||||
}
|
||||
if rel == "." {
|
||||
rel = ""
|
||||
}
|
||||
return FSFileEntry{
|
||||
Path: filepath.ToSlash(rel),
|
||||
IsDir: info.IsDir(),
|
||||
Size: info.Size(),
|
||||
Mode: uint32(info.Mode().Perm()),
|
||||
ModTime: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func applyUnifiedPatch(original, patch string) (string, error) {
|
||||
lines := strings.Split(original, "\n")
|
||||
out := make([]string, 0, len(lines))
|
||||
index := 0
|
||||
patchLines := strings.Split(patch, "\n")
|
||||
hunksApplied := 0
|
||||
|
||||
for i := 0; i < len(patchLines); i++ {
|
||||
line := patchLines[i]
|
||||
if !strings.HasPrefix(line, "@@") {
|
||||
continue
|
||||
}
|
||||
|
||||
origStart, err := parseUnifiedHunkHeader(line)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
origStart--
|
||||
if origStart < 0 {
|
||||
origStart = 0
|
||||
}
|
||||
if origStart > len(lines) {
|
||||
return "", fmt.Errorf("patch out of range")
|
||||
}
|
||||
|
||||
out = append(out, lines[index:origStart]...)
|
||||
index = origStart
|
||||
hunksApplied++
|
||||
|
||||
for i+1 < len(patchLines) {
|
||||
next := patchLines[i+1]
|
||||
if strings.HasPrefix(next, "@@") {
|
||||
break
|
||||
}
|
||||
i++
|
||||
|
||||
if next == "" {
|
||||
if i == len(patchLines)-1 {
|
||||
break
|
||||
}
|
||||
return "", fmt.Errorf("invalid patch line")
|
||||
}
|
||||
if next[0] == '\\' {
|
||||
continue
|
||||
}
|
||||
op := next[0]
|
||||
text := next[1:]
|
||||
switch op {
|
||||
case ' ':
|
||||
if index >= len(lines) || lines[index] != text {
|
||||
return "", fmt.Errorf("patch context mismatch")
|
||||
}
|
||||
out = append(out, text)
|
||||
index++
|
||||
case '-':
|
||||
if index >= len(lines) || lines[index] != text {
|
||||
return "", fmt.Errorf("patch delete mismatch")
|
||||
}
|
||||
index++
|
||||
case '+':
|
||||
out = append(out, text)
|
||||
default:
|
||||
return "", fmt.Errorf("invalid patch operation")
|
||||
}
|
||||
}
|
||||
}
|
||||
if hunksApplied == 0 {
|
||||
return "", fmt.Errorf("patch contains no hunks")
|
||||
}
|
||||
|
||||
out = append(out, lines[index:]...)
|
||||
return strings.Join(out, "\n"), nil
|
||||
}
|
||||
|
||||
func parseUnifiedHunkHeader(header string) (int, error) {
|
||||
trimmed := strings.TrimPrefix(header, "@@")
|
||||
trimmed = strings.TrimSpace(trimmed)
|
||||
if !strings.HasPrefix(trimmed, "-") {
|
||||
return 0, fmt.Errorf("invalid hunk header")
|
||||
}
|
||||
parts := strings.SplitN(trimmed, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return 0, fmt.Errorf("invalid hunk header")
|
||||
}
|
||||
|
||||
origPart := strings.TrimPrefix(parts[0], "-")
|
||||
origFields := strings.SplitN(origPart, ",", 2)
|
||||
origStart, err := strconv.Atoi(origFields[0])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid hunk header")
|
||||
}
|
||||
return origStart, nil
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
|
||||
// SelectMemoryModel selects a chat model for memory operations.
|
||||
func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) {
|
||||
// First try to get the memory-enabled model.
|
||||
memoryModel, err := modelsService.GetByEnableAs(ctx, EnableAsMemory)
|
||||
if err == nil {
|
||||
provider, err := FetchProviderByID(ctx, queries, memoryModel.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return memoryModel, provider, nil
|
||||
}
|
||||
|
||||
// Fallback to chat model.
|
||||
chatModel, err := modelsService.GetByEnableAs(ctx, EnableAsChat)
|
||||
if err == nil {
|
||||
provider, err := FetchProviderByID(ctx, queries, chatModel.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return chatModel, provider, nil
|
||||
}
|
||||
|
||||
// If no enabled models, try to find any chat model.
|
||||
candidates, err := modelsService.ListByType(ctx, ModelTypeChat)
|
||||
if err != nil || len(candidates) == 0 {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations")
|
||||
}
|
||||
|
||||
selected := candidates[0]
|
||||
provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID)
|
||||
if err != nil {
|
||||
return GetResponse{}, sqlc.LlmProvider{}, err
|
||||
}
|
||||
return selected, provider, nil
|
||||
}
|
||||
|
||||
// FetchProviderByID fetches a provider by ID.
|
||||
func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) {
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
return sqlc.LlmProvider{}, fmt.Errorf("provider id missing")
|
||||
}
|
||||
parsed, err := parseUUID(providerID)
|
||||
if err != nil {
|
||||
return sqlc.LlmProvider{}, err
|
||||
}
|
||||
return queries.GetLlmProviderByID(ctx, parsed)
|
||||
}
|
||||
@@ -15,7 +15,7 @@ type Server struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, fsHandler *handlers.FSHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler) *Server {
|
||||
func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, swaggerHandler *handlers.SwaggerHandler, chatHandler *handlers.ChatHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, containerdHandler *handlers.ContainerdHandler) *Server {
|
||||
if addr == "" {
|
||||
addr = ":8080"
|
||||
}
|
||||
@@ -29,6 +29,9 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
|
||||
if path == "/ping" || path == "/api/swagger.json" || path == "/auth/login" {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(path, "/mcp/") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(path, "/api/docs") {
|
||||
return true
|
||||
}
|
||||
@@ -47,9 +50,6 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
|
||||
if embeddingsHandler != nil {
|
||||
embeddingsHandler.Register(e)
|
||||
}
|
||||
if fsHandler != nil {
|
||||
fsHandler.Register(e)
|
||||
}
|
||||
if swaggerHandler != nil {
|
||||
swaggerHandler.Register(e)
|
||||
}
|
||||
@@ -62,6 +62,9 @@ func NewServer(addr string, jwtSecret string, pingHandler *handlers.PingHandler,
|
||||
if modelsHandler != nil {
|
||||
modelsHandler.Register(e)
|
||||
}
|
||||
if containerdHandler != nil {
|
||||
containerdHandler.Register(e)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
echo: e,
|
||||
|
||||
Reference in New Issue
Block a user