diff --git a/cmd/agent/main.go b/cmd/agent/main.go index a8538273..3688f5f5 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -324,9 +324,9 @@ func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, lo } } -func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, containerdHandler *handlers.ContainerdHandler) *memprovider.Registry { +func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, manager *mcp.Manager) *memprovider.Registry { registry := memprovider.NewRegistry(log) - builtinRuntime := handlers.NewBuiltinMemoryRuntime(containerdHandler.FSService()) + builtinRuntime := handlers.NewBuiltinMemoryRuntime(manager) registry.RegisterFactory(memprovider.BuiltinType, func(id string, config map[string]any) (memprovider.Provider, error) { return memprovider.NewBuiltinProvider(log, builtinRuntime, chatService, accountService), nil }) @@ -481,7 +481,7 @@ func provideMemoryHandler(log *slog.Logger, botService *bots.Service, accountSer h := handlers.NewMemoryHandler(log, botService, accountService) h.SetMemoryRegistry(memoryRegistry) h.SetSettingsService(settingsService) - h.SetFSService(containerdHandler.FSService()) + h.SetMCPClientProvider(manager) return h } @@ -495,16 +495,9 @@ func provideMessageHandler(log *slog.Logger, chatService *conversation.Service, return h } -func provideMediaService(log *slog.Logger, cfg config.Config) (*media.Service, error) { - dataRoot := strings.TrimSpace(cfg.MCP.DataRoot) - if dataRoot == "" { - dataRoot = config.DefaultDataRoot - } - provider, err := containerfs.New(dataRoot) - if err != nil { - return nil, fmt.Errorf("init media provider: %w", err) - } - return media.NewService(log, provider), nil +func provideMediaService(log *slog.Logger, manager *mcp.Manager) *media.Service { + provider := containerfs.New(manager) + return media.NewService(log, provider) } func provideUsersHandler(log *slog.Logger, accountService *accounts.Service, identityService *identities.Service, botService *bots.Service, routeService *route.DBService, channelStore *channel.Store, channelLifecycle *channel.Lifecycle, channelManager *channel.Manager, registry *channel.Registry) *handlers.UsersHandler { @@ -638,7 +631,7 @@ func startContainerReconciliation(lc fx.Lifecycle, containerdHandler *handlers.C }) } -func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService, channelManager *channel.Manager, modelsService *models.Service) { +func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler, manager *mcp.Manager, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService, channelManager *channel.Manager, modelsService *models.Service) { fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo()) lc.Append(fx.Hook{ @@ -647,6 +640,10 @@ func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutd return err } botService.SetContainerLifecycle(containerdHandler) + botService.SetContainerReachability(func(ctx context.Context, botID string) error { + _, err := manager.MCPClient(ctx, botID) + return err + }) botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter( mcpchecker.NewChecker(logger, mcpConnService, toolGateway), )) diff --git a/cmd/mcp/Dockerfile b/cmd/mcp/Dockerfile deleted file mode 100644 index a37ba27c..00000000 --- a/cmd/mcp/Dockerfile +++ /dev/null @@ -1,30 +0,0 @@ -FROM golang:1.25-alpine AS build - -WORKDIR /src -COPY go.mod go.sum ./ -RUN go mod download - -COPY . . -ARG TARGETARCH -ARG COMMIT_HASH=unknown -RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH:-amd64} \ - go build -trimpath -ldflags "-s -w -X github.com/memohai/memoh/internal/version.CommitHash=${COMMIT_HASH}" -o /out/mcp ./cmd/mcp - -FROM alpine:latest - -# Base utilities -RUN apk add --no-cache grep curl bash - -# Node.js + npm (provides npx for JS/TS MCP servers) -RUN apk add --no-cache nodejs npm - -# Python 3 + uv (provides uvx for Python MCP servers) -RUN apk add --no-cache python3 && \ - curl -LsSf https://astral.sh/uv/install.sh | sh && \ - ln -sf /root/.local/bin/uv /usr/local/bin/uv && \ - ln -sf /root/.local/bin/uvx /usr/local/bin/uvx - -WORKDIR /app -COPY --from=build /out/mcp /opt/mcp -COPY cmd/mcp/template /opt/mcp-template -ENTRYPOINT ["/bin/sh","-lc","bootstrap(){ [ -e /app/mcp ] || { mkdir -p /app; [ -f /opt/mcp ] && cp -a /opt/mcp /app/mcp 2>/dev/null || true; }; }; bootstrap; if [ -x /app/mcp ]; then exec /app/mcp \"$@\"; fi; exec /opt/mcp \"$@\"","--"] diff --git a/cmd/mcp/entrypoint.sh b/cmd/mcp/entrypoint.sh new file mode 100644 index 00000000..118d32e4 --- /dev/null +++ b/cmd/mcp/entrypoint.sh @@ -0,0 +1,5 @@ +#!/bin/sh +# Copy binary to writable layer so it survives snapshot restores. +[ -e /app/mcp ] || { mkdir -p /app; [ -f /opt/mcp ] && cp -a /opt/mcp /app/mcp 2>/dev/null || true; } +if [ -x /app/mcp ]; then exec /app/mcp "$@"; fi +exec /opt/mcp "$@" diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index 9110eceb..a2f28684 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -2,41 +2,87 @@ package main import ( "context" - "errors" - "io" + "io/fs" "log/slog" + "net" "os" "os/signal" + "path/filepath" "syscall" "github.com/memohai/memoh/internal/logger" - "github.com/memohai/memoh/internal/version" - gomcp "github.com/modelcontextprotocol/go-sdk/mcp" + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" ) +const ( + defaultListenAddr = ":9090" + templateDir = "/opt/mcp-template" +) + +// initDataDir ensures /data exists and seeds template files on first boot. +func initDataDir() { + if err := os.MkdirAll(defaultWorkDir, 0o755); err != nil { + logger.Warn("failed to create data dir", slog.Any("error", err)) + return + } + + entries, err := os.ReadDir(templateDir) + if err != nil { + if !os.IsNotExist(err) { + logger.Warn("failed to read template dir", slog.String("dir", templateDir), slog.Any("error", err)) + } + return + } + for _, e := range entries { + if e.IsDir() { + continue + } + dst := filepath.Join(defaultWorkDir, e.Name()) + if _, err := os.Stat(dst); err == nil { + continue + } + data, err := os.ReadFile(filepath.Join(templateDir, e.Name())) + if err != nil { + continue + } + if err := os.WriteFile(dst, data, fs.FileMode(0o644)); err != nil { + logger.Warn("failed to seed template", slog.String("file", e.Name()), slog.Any("error", err)) + } + } +} + func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - // File tools (read/write/list/edit) are provided by the agent's MCP tool gateway, not this binary. - server := gomcp.NewServer( - &gomcp.Implementation{Name: "memoh-mcp", Version: version.GetInfo()}, - nil, - ) - err := server.Run(ctx, &gomcp.StdioTransport{}) - if ctx.Err() != nil { - return + initDataDir() + + addr := os.Getenv("MCP_LISTEN_ADDR") + if addr == "" { + addr = defaultListenAddr } - if err == nil { - logger.Warn("mcp server exited without error; waiting for shutdown signal") + + lis, err := net.Listen("tcp", addr) + if err != nil { + logger.Error("failed to listen", slog.String("addr", addr), slog.Any("error", err)) + os.Exit(1) + } + + srv := grpc.NewServer() + pb.RegisterContainerServiceServer(srv, &containerServer{}) + reflection.Register(srv) + + go func() { <-ctx.Done() - return + logger.Info("shutting down gRPC server") + srv.GracefulStop() + }() + + logger.Info("mcp gRPC server listening", slog.String("addr", addr)) + if err := srv.Serve(lis); err != nil { + logger.Error("gRPC server failed", slog.Any("error", err)) + os.Exit(1) } - if errors.Is(err, io.EOF) { - logger.Warn("mcp stdio closed; waiting for shutdown signal") - <-ctx.Done() - return - } - logger.Error("mcp server failed", slog.Any("error", err)) - os.Exit(1) } diff --git a/cmd/mcp/server.go b/cmd/mcp/server.go new file mode 100644 index 00000000..e368ad86 --- /dev/null +++ b/cmd/mcp/server.go @@ -0,0 +1,451 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "io/fs" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + "unicode/utf8" + + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + readMaxLines = 200 + readMaxBytes = 5120 + readMaxLineLen = 1000 + binaryProbeBytes = 8 * 1024 + rawChunkSize = 64 * 1024 + defaultWorkDir = "/data" + defaultTimeout = 30 +) + +type containerServer struct { + pb.UnimplementedContainerServiceServer +} + +func (s *containerServer) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) { + path := req.GetPath() + if path == "" { + return nil, status.Error(codes.InvalidArgument, "path is required") + } + path = resolvePath(path) + + f, err := os.Open(path) + if err != nil { + return nil, status.Errorf(codes.NotFound, "open: %v", err) + } + defer f.Close() + + probe := make([]byte, binaryProbeBytes) + n, _ := f.Read(probe) + if bytes.IndexByte(probe[:n], 0) >= 0 { + return &pb.ReadFileResponse{Binary: true}, nil + } + if _, err := f.Seek(0, io.SeekStart); err != nil { + return nil, status.Errorf(codes.Internal, "seek: %v", err) + } + + lineOffset := int(req.GetLineOffset()) + if lineOffset < 1 { + lineOffset = 1 + } + nLines := int(req.GetNLines()) + if nLines < 1 || nLines > readMaxLines { + nLines = readMaxLines + } + + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + currentLine := 0 + totalLines := 0 + var out strings.Builder + linesRead := 0 + bytesWritten := 0 + + for scanner.Scan() { + currentLine++ + totalLines = currentLine + if currentLine < lineOffset { + continue + } + if linesRead >= nLines { + continue // keep scanning to count total lines + } + + line := scanner.Text() + if utf8.RuneCountInString(line) > readMaxLineLen { + line = truncateRunes(line, readMaxLineLen) + "..." + } + + formatted := fmt.Sprintf("%6d\t%s\n", currentLine, line) + if bytesWritten+len(formatted) > readMaxBytes { + break + } + out.WriteString(formatted) + bytesWritten += len(formatted) + linesRead++ + } + + // Drain remaining lines for total count. + for scanner.Scan() { + totalLines++ + } + + return &pb.ReadFileResponse{ + Content: out.String(), + TotalLines: int32(totalLines), + }, nil +} + +func (s *containerServer) WriteFile(_ context.Context, req *pb.WriteFileRequest) (*pb.WriteFileResponse, error) { + path := req.GetPath() + if path == "" { + return nil, status.Error(codes.InvalidArgument, "path is required") + } + path = resolvePath(path) + + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return nil, status.Errorf(codes.Internal, "mkdir: %v", err) + } + if err := os.WriteFile(path, req.GetContent(), 0o644); err != nil { + return nil, status.Errorf(codes.Internal, "write: %v", err) + } + return &pb.WriteFileResponse{}, nil +} + +func (s *containerServer) ListDir(_ context.Context, req *pb.ListDirRequest) (*pb.ListDirResponse, error) { + dir := req.GetPath() + if dir == "" { + dir = "." + } + dir = resolvePath(dir) + + var entries []*pb.FileEntry + + if req.GetRecursive() { + err := filepath.WalkDir(dir, func(p string, d fs.DirEntry, err error) error { + if err != nil { + return nil // skip errors + } + rel, _ := filepath.Rel(dir, p) + if rel == "." { + return nil + } + entry, _ := buildFileEntry(rel, p, d) + if entry != nil { + entries = append(entries, entry) + } + return nil + }) + if err != nil { + return nil, status.Errorf(codes.NotFound, "walk: %v", err) + } + } else { + dirEntries, err := os.ReadDir(dir) + if err != nil { + return nil, status.Errorf(codes.NotFound, "readdir: %v", err) + } + for _, d := range dirEntries { + entry, _ := buildFileEntry(d.Name(), filepath.Join(dir, d.Name()), d) + if entry != nil { + entries = append(entries, entry) + } + } + } + + return &pb.ListDirResponse{Entries: entries}, nil +} + +func (s *containerServer) Exec(stream pb.ContainerService_ExecServer) error { + // Receive first message to get command details + firstMsg, err := stream.Recv() + if err != nil { + return status.Error(codes.InvalidArgument, "failed to receive exec config") + } + + command := firstMsg.GetCommand() + if command == "" { + return status.Error(codes.InvalidArgument, "command is required") + } + + workDir := firstMsg.GetWorkDir() + if workDir == "" { + workDir = defaultWorkDir + } + + timeout := int(firstMsg.GetTimeoutSeconds()) + if timeout <= 0 { + timeout = defaultTimeout + } + + ctx, cancel := context.WithTimeout(stream.Context(), time.Duration(timeout)*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) + cmd.Dir = workDir + if len(firstMsg.GetEnv()) > 0 { + cmd.Env = append(os.Environ(), firstMsg.GetEnv()...) + } + + // Setup stdin pipe for bidirectional streaming + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return status.Errorf(codes.Internal, "stdin pipe: %v", err) + } + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return status.Errorf(codes.Internal, "stdout pipe: %v", err) + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + return status.Errorf(codes.Internal, "stderr pipe: %v", err) + } + + if err := cmd.Start(); err != nil { + return status.Errorf(codes.Internal, "start: %v", err) + } + + // Handle stdin from stream + go func() { + for { + msg, err := stream.Recv() + if err != nil { + _ = stdinPipe.Close() + return + } + if data := msg.GetStdinData(); len(data) > 0 { + _, _ = stdinPipe.Write(data) + } + } + }() + + // Stream stdout/stderr to client + done := make(chan struct{}) + go func() { + defer close(done) + streamPipe(stream, stdoutPipe, pb.ExecOutput_STDOUT) + }() + streamPipe(stream, stderrPipe, pb.ExecOutput_STDERR) + <-done + + exitCode := int32(0) + if err := cmd.Wait(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = int32(exitErr.ExitCode()) + } else { + exitCode = -1 + } + } + + return stream.Send(&pb.ExecOutput{ + Stream: pb.ExecOutput_EXIT, + ExitCode: exitCode, + }) +} + +func (s *containerServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { + path := req.GetPath() + if path == "" { + return status.Error(codes.InvalidArgument, "path is required") + } + path = resolvePath(path) + + f, err := os.Open(path) + if err != nil { + return status.Errorf(codes.NotFound, "open: %v", err) + } + defer f.Close() + + buf := make([]byte, rawChunkSize) + for { + n, err := f.Read(buf) + if n > 0 { + if sendErr := stream.Send(&pb.DataChunk{Data: buf[:n]}); sendErr != nil { + return sendErr + } + } + if err == io.EOF { + break + } + if err != nil { + return status.Errorf(codes.Internal, "read: %v", err) + } + } + return nil +} + +func (s *containerServer) WriteRaw(stream pb.ContainerService_WriteRawServer) error { + var f *os.File + var written int64 + + for { + chunk, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + return err + } + + if f == nil { + path := chunk.GetPath() + if path == "" { + return status.Error(codes.InvalidArgument, "first chunk must include path") + } + path = resolvePath(path) + if mkErr := os.MkdirAll(filepath.Dir(path), 0o755); mkErr != nil { + return status.Errorf(codes.Internal, "mkdir: %v", mkErr) + } + f, err = os.Create(path) + if err != nil { + return status.Errorf(codes.Internal, "create: %v", err) + } + defer f.Close() + } + + if len(chunk.GetData()) > 0 { + n, err := f.Write(chunk.GetData()) + written += int64(n) + if err != nil { + return status.Errorf(codes.Internal, "write: %v", err) + } + } + } + + return stream.SendAndClose(&pb.WriteRawResponse{BytesWritten: written}) +} + +func (s *containerServer) DeleteFile(_ context.Context, req *pb.DeleteFileRequest) (*pb.DeleteFileResponse, error) { + path := req.GetPath() + if path == "" { + return nil, status.Error(codes.InvalidArgument, "path is required") + } + path = resolvePath(path) + + var err error + if req.GetRecursive() { + err = os.RemoveAll(path) + } else { + err = os.Remove(path) + } + if err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "delete: %v", err) + } + return &pb.DeleteFileResponse{}, nil +} + +func (s *containerServer) Stat(_ context.Context, req *pb.StatRequest) (*pb.StatResponse, error) { + path := req.GetPath() + if path == "" { + return nil, status.Error(codes.InvalidArgument, "path is required") + } + path = resolvePath(path) + + info, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return nil, status.Error(codes.NotFound, "not found") + } + return nil, status.Errorf(codes.Internal, "stat: %v", err) + } + return &pb.StatResponse{ + Entry: &pb.FileEntry{ + Path: filepath.Base(path), + IsDir: info.IsDir(), + Size: info.Size(), + Mode: info.Mode().String(), + ModTime: info.ModTime().Format(time.RFC3339), + }, + }, nil +} + +func (s *containerServer) Mkdir(_ context.Context, req *pb.MkdirRequest) (*pb.MkdirResponse, error) { + path := req.GetPath() + if path == "" { + return nil, status.Error(codes.InvalidArgument, "path is required") + } + path = resolvePath(path) + + if err := os.MkdirAll(path, 0o755); err != nil { + return nil, status.Errorf(codes.Internal, "mkdir: %v", err) + } + return &pb.MkdirResponse{}, nil +} + +func (s *containerServer) Rename(_ context.Context, req *pb.RenameRequest) (*pb.RenameResponse, error) { + oldPath := req.GetOldPath() + newPath := req.GetNewPath() + if oldPath == "" || newPath == "" { + return nil, status.Error(codes.InvalidArgument, "old_path and new_path are required") + } + oldPath = resolvePath(oldPath) + newPath = resolvePath(newPath) + + if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil { + return nil, status.Errorf(codes.Internal, "mkdir parent: %v", err) + } + if err := os.Rename(oldPath, newPath); err != nil { + return nil, status.Errorf(codes.Internal, "rename: %v", err) + } + return &pb.RenameResponse{}, nil +} + +func streamPipe(stream pb.ContainerService_ExecServer, r io.Reader, st pb.ExecOutput_Stream) { + buf := make([]byte, 4096) + for { + n, err := r.Read(buf) + if n > 0 { + _ = stream.Send(&pb.ExecOutput{ + Stream: st, + Data: buf[:n], + }) + } + if err != nil { + break + } + } +} + +func buildFileEntry(name, fullPath string, d fs.DirEntry) (*pb.FileEntry, error) { + info, err := d.Info() + if err != nil { + return nil, err + } + return &pb.FileEntry{ + Path: name, + IsDir: d.IsDir(), + Size: info.Size(), + Mode: info.Mode().String(), + ModTime: info.ModTime().Format(time.RFC3339), + }, nil +} + +func resolvePath(path string) string { + if filepath.IsAbs(path) { + return filepath.Clean(path) + } + return filepath.Join(defaultWorkDir, path) +} + +func truncateRunes(s string, max int) string { + pos := 0 + count := 0 + for pos < len(s) && count < max { + _, size := utf8.DecodeRuneInString(s[pos:]) + pos += size + count++ + } + return s[:pos] +} diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 8be69c5b..290b7889 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -220,9 +220,9 @@ func provideAgentRuntimeManager(log *slog.Logger, cfg config.Config) *agentrunti func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memprovider.LLM { return &lazyLLMClient{modelsService: modelsService, queries: queries, timeout: 30 * time.Second, logger: log} } -func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, containerdHandler *handlers.ContainerdHandler) *memprovider.Registry { +func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, manager *mcp.Manager) *memprovider.Registry { registry := memprovider.NewRegistry(log) - builtinRuntime := handlers.NewBuiltinMemoryRuntime(containerdHandler.FSService()) + builtinRuntime := handlers.NewBuiltinMemoryRuntime(manager) registry.RegisterFactory(memprovider.BuiltinType, func(id string, config map[string]any) (memprovider.Provider, error) { return memprovider.NewBuiltinProvider(log, builtinRuntime, chatService, accountService), nil }) @@ -335,7 +335,7 @@ func provideMemoryHandler(log *slog.Logger, botService *bots.Service, accountSer h := handlers.NewMemoryHandler(log, botService, accountService) h.SetMemoryRegistry(memoryRegistry) h.SetSettingsService(settingsService) - h.SetFSService(containerdHandler.FSService()) + h.SetMCPClientProvider(manager) return h } func provideAuthHandler(log *slog.Logger, accountService *accounts.Service, rc *boot.RuntimeConfig) *handlers.AuthHandler { @@ -356,16 +356,9 @@ func (h *memohAuthHandler) Register(e *echo.Echo) { e.POST("/api/auth/login", h.inner.Login) e.POST("/api/auth/refresh", h.inner.Refresh) } -func provideMediaService(log *slog.Logger, cfg config.Config) (*media.Service, error) { - dataRoot := strings.TrimSpace(cfg.MCP.DataRoot) - if dataRoot == "" { - dataRoot = config.DefaultDataRoot - } - provider, err := containerfs.New(dataRoot) - if err != nil { - return nil, fmt.Errorf("init media provider: %w", err) - } - return media.NewService(log, provider), nil +func provideMediaService(log *slog.Logger, manager *mcp.Manager) *media.Service { + provider := containerfs.New(manager) + return media.NewService(log, provider) } func provideUsersHandler(log *slog.Logger, accountService *accounts.Service, identityService *identities.Service, botService *bots.Service, routeService *route.DBService, channelStore *channel.Store, channelLifecycle *channel.Lifecycle, channelManager *channel.Manager, registry *channel.Registry) *handlers.UsersHandler { return handlers.NewUsersHandler(log, accountService, identityService, botService, routeService, channelStore, channelLifecycle, channelManager, registry) @@ -506,7 +499,7 @@ func startAgentRuntime(lc fx.Lifecycle, manager *agentruntime.Manager) { OnStop: func(ctx context.Context) error { return manager.Stop(ctx) }, }) } -func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *memohServer, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService, channelManager *channel.Manager) { +func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *memohServer, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler, manager *mcp.Manager, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService, channelManager *channel.Manager) { fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo()) lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { @@ -514,6 +507,10 @@ func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *memohServer, shutdow return err } botService.SetContainerLifecycle(containerdHandler) + botService.SetContainerReachability(func(ctx context.Context, botID string) error { + _, err := manager.MCPClient(ctx, botID) + return err + }) botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(mcpchecker.NewChecker(logger, mcpConnService, toolGateway))) botService.AddRuntimeChecker(healthcheck.NewRuntimeCheckerAdapter(channelchecker.NewChecker(logger, channelManager))) go func() { diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 61d3c385..0a316402 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -314,7 +314,6 @@ CREATE TABLE IF NOT EXISTS containers ( status TEXT NOT NULL DEFAULT 'created', namespace TEXT NOT NULL DEFAULT 'default', auto_start BOOLEAN NOT NULL DEFAULT true, - host_path TEXT, container_path TEXT NOT NULL DEFAULT '/data', created_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), diff --git a/db/migrations/0024_drop_host_path.down.sql b/db/migrations/0024_drop_host_path.down.sql new file mode 100644 index 00000000..539c6074 --- /dev/null +++ b/db/migrations/0024_drop_host_path.down.sql @@ -0,0 +1,3 @@ +-- 0024_drop_host_path (rollback) +-- Re-add host_path column to containers table +ALTER TABLE containers ADD COLUMN IF NOT EXISTS host_path TEXT; diff --git a/db/migrations/0024_drop_host_path.up.sql b/db/migrations/0024_drop_host_path.up.sql new file mode 100644 index 00000000..9becc778 --- /dev/null +++ b/db/migrations/0024_drop_host_path.up.sql @@ -0,0 +1,3 @@ +-- 0024_drop_host_path +-- Remove host_path column from containers table (replaced by gRPC container access) +ALTER TABLE containers DROP COLUMN IF EXISTS host_path; diff --git a/db/queries/containers.sql b/db/queries/containers.sql index 0917a91a..132488da 100644 --- a/db/queries/containers.sql +++ b/db/queries/containers.sql @@ -1,7 +1,7 @@ -- name: UpsertContainer :exec INSERT INTO containers ( bot_id, container_id, container_name, image, status, namespace, auto_start, - host_path, container_path, last_started_at, last_stopped_at + container_path, last_started_at, last_stopped_at ) VALUES ( sqlc.arg(bot_id), @@ -11,7 +11,6 @@ VALUES ( sqlc.arg(status), sqlc.arg(namespace), sqlc.arg(auto_start), - sqlc.arg(host_path), sqlc.arg(container_path), sqlc.arg(last_started_at), sqlc.arg(last_stopped_at) @@ -23,7 +22,6 @@ ON CONFLICT (container_id) DO UPDATE SET status = EXCLUDED.status, namespace = EXCLUDED.namespace, auto_start = EXCLUDED.auto_start, - host_path = EXCLUDED.host_path, container_path = EXCLUDED.container_path, last_started_at = EXCLUDED.last_started_at, last_stopped_at = EXCLUDED.last_stopped_at, diff --git a/devenv/mcp-build.sh b/devenv/mcp-build.sh index 2e4bbd37..f4acab56 100755 --- a/devenv/mcp-build.sh +++ b/devenv/mcp-build.sh @@ -28,10 +28,13 @@ LAYER1_SHA=$(sha256sum "$BASE_ROOTFS" | cut -d' ' -f1) mkdir -p "$WORK/$LAYER1_SHA" ln -s "$BASE_ROOTFS" "$WORK/$LAYER1_SHA/layer.tar" -# Layer 2: compiled binary overlay +# Layer 2: compiled binary + template + entrypoint overlay mkdir -p "$WORK/overlay/opt" cp "$MCP_BINARY" "$WORK/overlay/opt/mcp" chmod +x "$WORK/overlay/opt/mcp" +cp -a /workspace/cmd/mcp/template "$WORK/overlay/opt/mcp-template" +cp /workspace/cmd/mcp/entrypoint.sh "$WORK/overlay/opt/entrypoint.sh" +chmod +x "$WORK/overlay/opt/entrypoint.sh" tar -cf "$WORK/layer2.tar" -C "$WORK/overlay" opt LAYER2_SHA=$(sha256sum "$WORK/layer2.tar" | cut -d' ' -f1) mkdir -p "$WORK/$LAYER2_SHA" diff --git a/docker/Dockerfile.mcp b/docker/Dockerfile.mcp index 2de72f1f..e6b1c337 100644 --- a/docker/Dockerfile.mcp +++ b/docker/Dockerfile.mcp @@ -38,4 +38,6 @@ RUN apk add --no-cache python3 && \ WORKDIR /app COPY --from=build /out/mcp /opt/mcp COPY cmd/mcp/template /opt/mcp-template -ENTRYPOINT ["/bin/sh","-lc","bootstrap(){ [ -e /app/mcp ] || { mkdir -p /app; [ -f /opt/mcp ] && cp -a /opt/mcp /app/mcp 2>/dev/null || true; }; }; bootstrap; if [ -x /app/mcp ]; then exec /app/mcp \"$@\"; fi; exec /opt/mcp \"$@\"","--"] +COPY cmd/mcp/entrypoint.sh /opt/entrypoint.sh +RUN chmod +x /opt/entrypoint.sh +ENTRYPOINT ["/opt/entrypoint.sh"] diff --git a/go.mod b/go.mod index d1ec9d81..7371edc2 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,8 @@ require ( github.com/wneessen/go-mail v0.7.2 go.uber.org/fx v1.24.0 golang.org/x/crypto v0.48.0 + google.golang.org/grpc v1.78.0 + google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 ) @@ -126,6 +128,4 @@ require ( golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.42.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect - google.golang.org/grpc v1.78.0 // indirect - google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/internal/bots/service.go b/internal/bots/service.go index e4eb4319..da852385 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "log/slog" - "os" "strings" "time" @@ -20,10 +19,11 @@ import ( // Service provides bot CRUD and membership management. type Service struct { - queries *sqlc.Queries - logger *slog.Logger - containerLifecycle ContainerLifecycle - checkers []RuntimeChecker + queries *sqlc.Queries + logger *slog.Logger + containerLifecycle ContainerLifecycle + checkers []RuntimeChecker + containerReachability func(ctx context.Context, botID string) error } const ( @@ -58,6 +58,12 @@ func (s *Service) SetContainerLifecycle(lc ContainerLifecycle) { s.containerLifecycle = lc } +// SetContainerReachability registers a function that checks whether a bot's +// container is reachable via gRPC. Returns nil on success, error otherwise. +func (s *Service) SetContainerReachability(fn func(ctx context.Context, botID string) error) { + s.containerReachability = fn +} + // AddRuntimeChecker registers an additional runtime checker. func (s *Service) AddRuntimeChecker(c RuntimeChecker) { if c != nil { @@ -750,8 +756,8 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot, includeD Type: BotCheckTypeContainerData, TitleKey: "bots.checks.titles.containerDataPath", Status: BotCheckStatusUnknown, - Summary: "Container host path check is pending.", - Detail: "Data path will be checked after initialization.", + Summary: "Container reachability check is pending.", + Detail: "Reachability will be checked after initialization.", }) if includeDynamic { checks = s.appendDynamicChecks(ctx, row.ID.String(), checks) @@ -788,8 +794,8 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot, includeD Type: BotCheckTypeContainerData, TitleKey: "bots.checks.titles.containerDataPath", Status: BotCheckStatusUnknown, - Summary: "Container host path check is skipped.", - Detail: "Bot is deleting and data path checks are paused.", + Summary: "Container reachability check is skipped.", + Detail: "Bot is deleting and reachability checks are paused.", }) if includeDynamic { checks = s.appendDynamicChecks(ctx, row.ID.String(), checks) @@ -824,14 +830,14 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot, includeD Summary: "Container task state is unknown.", Detail: "Task state cannot be determined without a container record.", }) - checks = append(checks, BotCheck{ - ID: BotCheckTypeContainerData, - Type: BotCheckTypeContainerData, - TitleKey: "bots.checks.titles.containerDataPath", - Status: BotCheckStatusUnknown, - Summary: "Container data path is unknown.", - Detail: "Data path cannot be determined without a container record.", - }) + checks = append(checks, BotCheck{ + ID: BotCheckTypeContainerData, + Type: BotCheckTypeContainerData, + TitleKey: "bots.checks.titles.containerDataPath", + Status: BotCheckStatusUnknown, + Summary: "Container reachability is unknown.", + Detail: "Reachability cannot be determined without a container record.", + }) if includeDynamic { checks = s.appendDynamicChecks(ctx, row.ID.String(), checks) } @@ -875,44 +881,23 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot, includeD taskCheck.Metadata = map[string]any{"status": taskStatus} checks = append(checks, taskCheck) - hostPath := "" - if containerRow.HostPath.Valid { - hostPath = strings.TrimSpace(containerRow.HostPath.String) - } dataCheck := BotCheck{ ID: BotCheckTypeContainerData, Type: BotCheckTypeContainerData, TitleKey: "bots.checks.titles.containerDataPath", Status: BotCheckStatusWarn, - Summary: "Container host path needs attention.", - Metadata: map[string]any{"host_path": hostPath}, + Summary: "Container reachability needs attention.", } - if hostPath == "" { - dataCheck.Detail = "host path is empty" - checks = append(checks, dataCheck) - if includeDynamic { - checks = s.appendDynamicChecks(ctx, row.ID.String(), checks) - } - return checks, nil - } - info, statErr := os.Stat(hostPath) - switch { - case statErr == nil && info != nil && info.IsDir(): + if s.containerReachability == nil { + dataCheck.Status = BotCheckStatusUnknown + dataCheck.Summary = "Container reachability check not configured." + } else if err := s.containerReachability(ctx, row.ID.String()); err != nil { + dataCheck.Status = BotCheckStatusError + dataCheck.Summary = "Container is not reachable via gRPC." + dataCheck.Detail = err.Error() + } else { dataCheck.Status = BotCheckStatusOK - dataCheck.Summary = "Container host path is accessible." - dataCheck.Detail = hostPath - case statErr == nil: - dataCheck.Status = BotCheckStatusError - dataCheck.Summary = "Container host path is invalid." - dataCheck.Detail = "host path is not a directory" - case errors.Is(statErr, os.ErrNotExist): - dataCheck.Status = BotCheckStatusError - dataCheck.Summary = "Container host path does not exist." - dataCheck.Detail = hostPath - default: - dataCheck.Status = BotCheckStatusWarn - dataCheck.Summary = "Container host path cannot be checked." - dataCheck.Detail = statErr.Error() + dataCheck.Summary = "Container is reachable via gRPC." } checks = append(checks, dataCheck) if includeDynamic { diff --git a/internal/config/config.go b/internal/config/config.go index 27baea30..c98edc78 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,6 +24,9 @@ const ( DefaultPGUser = "postgres" DefaultPGDatabase = "memoh" DefaultPGSSLMode = "disable" + DefaultQdrantURL = "http://127.0.0.1:6334" + DefaultQdrantCollection = "memory" + MCPGRPCPort = 9090 ) type Config struct { diff --git a/internal/containerd/network.go b/internal/containerd/network.go index b32e4032..95cd7beb 100644 --- a/internal/containerd/network.go +++ b/internal/containerd/network.go @@ -11,31 +11,31 @@ import ( gocni "github.com/containerd/go-cni" ) -func setupCNINetwork(ctx context.Context, task client.Task, containerID string, CNIBinDir string, CNIConfDir string) error { +func setupCNINetwork(ctx context.Context, task client.Task, containerID string, CNIBinDir string, CNIConfDir string) (string, error) { if task == nil { - return ErrInvalidArgument + return "", ErrInvalidArgument } if containerID == "" { containerID = task.ID() } if containerID == "" { - return ErrInvalidArgument + return "", ErrInvalidArgument } pid := task.Pid() if pid == 0 { - return fmt.Errorf("task pid not available for %s", containerID) + return "", fmt.Errorf("task pid not available for %s", containerID) } if _, err := os.Stat(CNIConfDir); err != nil { - return fmt.Errorf("cni config dir missing: %s: %w", CNIConfDir, err) + return "", fmt.Errorf("cni config dir missing: %s: %w", CNIConfDir, err) } if _, err := os.Stat(CNIBinDir); err != nil { - return fmt.Errorf("cni bin dir missing: %s: %w", CNIBinDir, err) + return "", fmt.Errorf("cni bin dir missing: %s: %w", CNIBinDir, err) } netnsPath := filepath.Join("/proc", fmt.Sprint(pid), "ns", "net") if _, err := os.Stat(netnsPath); err != nil { - return fmt.Errorf("netns not found: %s: %w", netnsPath, err) + return "", fmt.Errorf("netns not found: %s: %w", netnsPath, err) } cni, err := gocni.New( @@ -43,26 +43,43 @@ func setupCNINetwork(ctx context.Context, task client.Task, containerID string, gocni.WithPluginConfDir(CNIConfDir), ) if err != nil { - return err + return "", err } if err := cni.Load(gocni.WithLoNetwork, gocni.WithDefaultConf); err != nil { - return err + return "", err } - _, err = cni.Setup(ctx, containerID, netnsPath) + result, err := cni.Setup(ctx, containerID, netnsPath) if err != nil { - if !isDuplicateAllocationError(err) { - return err + if !isDuplicateAllocationError(err) && !isVethExistsError(err) { + return "", err } - // Stale IPAM allocation (e.g. after container restart with persisted + // Stale IPAM allocation or veth exists (e.g. after container restart with persisted // /var/lib/cni). Remove may fail if the previous iptables/veth state // is already gone; ignore the error so the retry Setup still runs. _ = cni.Remove(ctx, containerID, netnsPath) - _, err = cni.Setup(ctx, containerID, netnsPath) + result, err = cni.Setup(ctx, containerID, netnsPath) if err != nil { - return err + return "", err } } - return nil + return extractIP(result), nil +} + +func extractIP(result *gocni.Result) string { + if result == nil { + return "" + } + for _, cfg := range result.Interfaces { + for _, ipCfg := range cfg.IPConfigs { + if ipCfg.IP != nil { + ip := ipCfg.IP.String() + if ip != "" && ip != "127.0.0.1" && ip != "::1" { + return ip + } + } + } + } + return "" } func removeCNINetwork(ctx context.Context, task client.Task, containerID string, CNIBinDir string, CNIConfDir string) error { @@ -112,3 +129,12 @@ func isDuplicateAllocationError(err error) bool { } return strings.Contains(err.Error(), "duplicate allocation") } + +// isVethExistsError returns true if the CNI setup failed because veth devices +// already exist (e.g. after container restart with stale network state). +func isVethExistsError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "already exists") +} diff --git a/internal/containerd/service.go b/internal/containerd/service.go index 5299a616..702447e4 100644 --- a/internal/containerd/service.go +++ b/internal/containerd/service.go @@ -4,9 +4,7 @@ import ( "context" "errors" "fmt" - "io" "log/slog" - "os" "runtime" "strings" "syscall" @@ -54,9 +52,7 @@ type DeleteContainerOptions struct { } type StartTaskOptions struct { - UseStdio bool Terminal bool - FIFODir string } type StopTaskOptions struct { @@ -95,10 +91,7 @@ type Service interface { DeleteTask(ctx context.Context, containerID string, opts *DeleteTaskOptions) error GetTaskInfo(ctx context.Context, containerID string) (TaskInfo, error) ListTasks(ctx context.Context, opts *ListTasksOptions) ([]TaskInfo, error) - ExecTask(ctx context.Context, containerID string, req ExecTaskRequest) (ExecTaskResult, error) - ExecTaskStreaming(ctx context.Context, containerID string, req ExecTaskRequest) (*ExecTaskSession, error) - - SetupNetwork(ctx context.Context, req NetworkSetupRequest) error + SetupNetwork(ctx context.Context, req NetworkSetupRequest) (NetworkResult, error) RemoveNetwork(ctx context.Context, req NetworkSetupRequest) error CommitSnapshot(ctx context.Context, snapshotter, name, key string) error @@ -440,21 +433,7 @@ func (s *DefaultService) StartContainer(ctx context.Context, containerID string, return err } - var ioCreator cio.Creator - if opts == nil || !opts.UseStdio { - ioCreator = cio.NullIO - } else { - cioOpts := []cio.Opt{cio.WithStdio} - if opts.Terminal { - cioOpts = append(cioOpts, cio.WithTerminal) - } - if opts.FIFODir != "" { - cioOpts = append(cioOpts, cio.WithFIFODir(opts.FIFODir)) - } - ioCreator = cio.NewCreator(cioOpts...) - } - - task, err := container.NewTask(ctx, ioCreator) + task, err := container.NewTask(ctx, cio.NullIO) if err != nil { return err } @@ -581,231 +560,23 @@ func (s *DefaultService) DeleteTask(ctx context.Context, containerID string, opt } if opts != nil && opts.Force { + // Kill and wait for exit before deleting; containerd rejects Delete on a + // still-running process even when force is requested. _ = task.Kill(ctx, syscall.SIGKILL) + if statusC, waitErr := task.Wait(ctx); waitErr == nil { + waitCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + select { + case <-statusC: + case <-waitCtx.Done(): + } + } } _, err = task.Delete(ctx) return err } -func (s *DefaultService) ExecTask(ctx context.Context, containerID string, req ExecTaskRequest) (ExecTaskResult, error) { - if containerID == "" || len(req.Args) == 0 { - return ExecTaskResult{}, ErrInvalidArgument - } - - ctx = s.withNamespace(ctx) - container, err := s.client.LoadContainer(ctx, containerID) - if err != nil { - return ExecTaskResult{}, err - } - - spec, err := container.Spec(ctx) - if err != nil { - return ExecTaskResult{}, err - } - if spec.Process == nil { - spec.Process = &specs.Process{} - } - - if len(req.Env) > 0 { - if err := oci.WithEnv(req.Env)(ctx, nil, nil, spec); err != nil { - return ExecTaskResult{}, err - } - } - - spec.Process.Args = req.Args - if req.WorkDir != "" { - spec.Process.Cwd = req.WorkDir - } - if req.Terminal { - spec.Process.Terminal = true - } - - task, err := s.getTask(ctx, containerID) - if err != nil { - return ExecTaskResult{}, err - } - - ioOpts := []cio.Opt{} - if req.Stdin != nil || req.Stdout != nil || req.Stderr != nil { - ioOpts = append(ioOpts, cio.WithStreams(req.Stdin, req.Stdout, req.Stderr)) - } else if req.UseStdio { - ioOpts = append(ioOpts, cio.WithStdio) - } - if req.Terminal { - ioOpts = append(ioOpts, cio.WithTerminal) - } - if strings.TrimSpace(req.FIFODir) != "" { - if err := os.MkdirAll(req.FIFODir, 0o755); err != nil { - return ExecTaskResult{}, err - } - ioOpts = append(ioOpts, cio.WithFIFODir(req.FIFODir)) - } - ioCreator := cio.NewCreator(ioOpts...) - - execID := fmt.Sprintf("exec-%d", time.Now().UnixNano()) - process, err := task.Exec(ctx, execID, spec.Process, ioCreator) - if err != nil { - return ExecTaskResult{}, err - } - defer process.Delete(ctx) - - statusC, err := process.Wait(ctx) - if err != nil { - return ExecTaskResult{}, err - } - if err := process.Start(ctx); err != nil { - return ExecTaskResult{}, err - } - - status := <-statusC - code, _, err := status.Result() - if err != nil { - return ExecTaskResult{}, err - } - - return ExecTaskResult{ExitCode: code}, nil -} - -func (s *DefaultService) ExecTaskStreaming(ctx context.Context, containerID string, req ExecTaskRequest) (*ExecTaskSession, error) { - if containerID == "" || len(req.Args) == 0 { - return nil, ErrInvalidArgument - } - - ctx = s.withNamespace(ctx) - container, err := s.client.LoadContainer(ctx, containerID) - if err != nil { - return nil, err - } - - spec, err := container.Spec(ctx) - if err != nil { - return nil, err - } - if spec.Process == nil { - spec.Process = &specs.Process{} - } - if len(req.Env) > 0 { - if err := oci.WithEnv(req.Env)(ctx, nil, nil, spec); err != nil { - return nil, err - } - } - spec.Process.Args = req.Args - if req.WorkDir != "" { - spec.Process.Cwd = req.WorkDir - } - if req.Terminal { - spec.Process.Terminal = true - } - - task, err := s.getTask(ctx, containerID) - if err != nil { - return nil, err - } - - stdinR, stdinW := io.Pipe() - stdoutR, stdoutW := io.Pipe() - stderrR, stderrW := io.Pipe() - - ioOpts := []cio.Opt{ - cio.WithStreams(stdinR, stdoutW, stderrW), - } - if req.Terminal { - ioOpts = append(ioOpts, cio.WithTerminal) - } - fifoDir, err := resolveExecFIFODir(req.FIFODir) - if err != nil { - _ = stdinR.Close() - _ = stdinW.Close() - _ = stdoutR.Close() - _ = stdoutW.Close() - _ = stderrR.Close() - _ = stderrW.Close() - return nil, err - } - ioOpts = append(ioOpts, cio.WithFIFODir(fifoDir)) - ioCreator := cio.NewCreator(ioOpts...) - - execID := fmt.Sprintf("exec-%d", time.Now().UnixNano()) - process, err := task.Exec(ctx, execID, spec.Process, ioCreator) - if err != nil { - _ = stdinR.Close() - _ = stdinW.Close() - _ = stdoutR.Close() - _ = stdoutW.Close() - _ = stderrR.Close() - _ = stderrW.Close() - return nil, err - } - - if err := process.Start(ctx); err != nil { - _, _ = process.Delete(ctx) - _ = stdinR.Close() - _ = stdinW.Close() - _ = stdoutR.Close() - _ = stdoutW.Close() - _ = stderrR.Close() - _ = stderrW.Close() - return nil, err - } - - wait := func() (ExecTaskResult, error) { - statusC, err := process.Wait(ctx) - if err != nil { - return ExecTaskResult{}, err - } - status := <-statusC - code, _, err := status.Result() - if err != nil { - return ExecTaskResult{}, err - } - _, _ = process.Delete(ctx) - _ = stdoutW.Close() - _ = stderrW.Close() - return ExecTaskResult{ExitCode: code}, nil - } - - closeFn := func() error { - _ = stdinW.Close() - _ = stdoutR.Close() - _ = stderrR.Close() - _ = stdinR.Close() - _ = stdoutW.Close() - _ = stderrW.Close() - _, err := process.Delete(ctx) - return err - } - - return &ExecTaskSession{ - Stdin: stdinW, - Stdout: stdoutR, - Stderr: stderrR, - Wait: wait, - Close: closeFn, - }, nil -} - -func resolveExecFIFODir(preferred string) (string, error) { - candidates := make([]string, 0, 3) - if p := strings.TrimSpace(preferred); p != "" { - candidates = append(candidates, p) - } - candidates = append(candidates, "/var/lib/containerd/memoh-fifo", "/tmp/memoh-containerd-fifo") - - var lastErr error - for _, dir := range candidates { - if err := os.MkdirAll(dir, 0o755); err == nil { - return dir, nil - } else { - lastErr = err - } - } - if lastErr == nil { - lastErr = fmt.Errorf("no fifo directory candidate available") - } - return "", lastErr -} - func (s *DefaultService) ListContainersByLabel(ctx context.Context, key, value string) ([]ContainerInfo, error) { if key == "" { return nil, ErrInvalidArgument @@ -946,13 +717,17 @@ func (s *DefaultService) SnapshotMounts(ctx context.Context, snapshotter, key st return result, nil } -func (s *DefaultService) SetupNetwork(ctx context.Context, req NetworkSetupRequest) error { +func (s *DefaultService) SetupNetwork(ctx context.Context, req NetworkSetupRequest) (NetworkResult, error) { ctx = s.withNamespace(ctx) task, err := s.getTask(ctx, req.ContainerID) if err != nil { - return err + return NetworkResult{}, err } - return setupCNINetwork(ctx, task, req.ContainerID, req.CNIBinDir, req.CNIConfDir) + ip, err := setupCNINetwork(ctx, task, req.ContainerID, req.CNIBinDir, req.CNIConfDir) + if err != nil { + return NetworkResult{}, err + } + return NetworkResult{IP: ip}, nil } func (s *DefaultService) RemoveNetwork(ctx context.Context, req NetworkSetupRequest) error { diff --git a/internal/containerd/service_apple.go b/internal/containerd/service_apple.go index 282d3ca7..f5fc1823 100644 --- a/internal/containerd/service_apple.go +++ b/internal/containerd/service_apple.go @@ -3,7 +3,6 @@ package containerd import ( "context" "fmt" - "io" "log/slog" "os" "path/filepath" @@ -352,58 +351,13 @@ func (s *AppleService) ListTasks(ctx context.Context, opts *ListTasksOptions) ([ return out, nil } -// ExecTask executes a command inside the container. -// -// Limitations compared to the containerd backend: -// - WorkDir: the Apple Container exec API (ExecCreateRequest) has no working-directory -// field; req.WorkDir is silently ignored. -// - Env: similarly, environment variables cannot be injected at exec time via this -// API; req.Env is silently ignored. -// - Stdin: the acgo exec interface does not expose a write channel, so req.Stdin -// is not connected and the process receives no stdin input. -// - Stderr: the Apple Container API returns stdout and stderr as a single combined -// stream; they cannot be routed to separate writers. The combined output is sent -// to req.Stdout when set, otherwise to req.Stderr when set, otherwise discarded. -// - FIFODir: a containerd-specific concept; not applicable here. -func (s *AppleService) ExecTask(ctx context.Context, containerID string, req ExecTaskRequest) (ExecTaskResult, error) { - if containerID == "" || len(req.Args) == 0 { - return ExecTaskResult{}, ErrInvalidArgument - } - if err := s.ensureHealthy(ctx); err != nil { - return ExecTaskResult{}, err - } - ctr, err := s.client.LoadContainer(ctx, containerID) - if err != nil { - return ExecTaskResult{}, err - } - var execOpts []acgo.ExecOpt - if req.Terminal { - execOpts = append(execOpts, acgo.WithExecTTY()) - } - result, err := ctr.Exec(ctx, req.Args, execOpts...) - if err != nil { - return ExecTaskResult{}, err - } - if result.Output != nil { - dest := req.Stdout - if dest == nil { - dest = req.Stderr - } - _, _ = io.Copy(dest, result.Output) - _ = result.Output.Close() - } - return ExecTaskResult{ExitCode: 0}, nil -} - -func (s *AppleService) ExecTaskStreaming(context.Context, string, ExecTaskRequest) (*ExecTaskSession, error) { - return nil, ErrNotSupported -} - // --------------------------------------------------------------------------- // Network (no-op — Apple Container handles networking natively) // --------------------------------------------------------------------------- -func (s *AppleService) SetupNetwork(context.Context, NetworkSetupRequest) error { return nil } +func (s *AppleService) SetupNetwork(context.Context, NetworkSetupRequest) (NetworkResult, error) { + return NetworkResult{}, nil +} func (s *AppleService) RemoveNetwork(context.Context, NetworkSetupRequest) error { return nil } // --------------------------------------------------------------------------- diff --git a/internal/containerd/types.go b/internal/containerd/types.go index 6b1c3193..a5b60a2c 100644 --- a/internal/containerd/types.go +++ b/internal/containerd/types.go @@ -2,7 +2,6 @@ package containerd import ( "errors" - "io" "time" ) @@ -102,26 +101,7 @@ type NetworkSetupRequest struct { CNIConfDir string } -type ExecTaskRequest struct { - Args []string - Env []string - WorkDir string - Terminal bool - UseStdio bool - FIFODir string - Stdin io.Reader - Stdout io.Writer - Stderr io.Writer +type NetworkResult struct { + IP string } -type ExecTaskSession struct { - Stdin io.WriteCloser - Stdout io.ReadCloser - Stderr io.ReadCloser - Wait func() (ExecTaskResult, error) - Close func() error -} - -type ExecTaskResult struct { - ExitCode uint32 -} diff --git a/internal/db/sqlc/containers.sql.go b/internal/db/sqlc/containers.sql.go index bf3a7683..a7794fa1 100644 --- a/internal/db/sqlc/containers.sql.go +++ b/internal/db/sqlc/containers.sql.go @@ -21,7 +21,7 @@ func (q *Queries) DeleteContainerByBotID(ctx context.Context, botID pgtype.UUID) } const getContainerByBotID = `-- name: GetContainerByBotID :one -SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE bot_id = $1 ORDER BY updated_at DESC LIMIT 1 +SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE bot_id = $1 ORDER BY updated_at DESC LIMIT 1 ` func (q *Queries) GetContainerByBotID(ctx context.Context, botID pgtype.UUID) (Container, error) { @@ -36,7 +36,6 @@ func (q *Queries) GetContainerByBotID(ctx context.Context, botID pgtype.UUID) (C &i.Status, &i.Namespace, &i.AutoStart, - &i.HostPath, &i.ContainerPath, &i.CreatedAt, &i.UpdatedAt, @@ -47,7 +46,7 @@ func (q *Queries) GetContainerByBotID(ctx context.Context, botID pgtype.UUID) (C } const listAutoStartContainers = `-- name: ListAutoStartContainers :many -SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE auto_start = true ORDER BY updated_at DESC +SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE auto_start = true ORDER BY updated_at DESC ` func (q *Queries) ListAutoStartContainers(ctx context.Context) ([]Container, error) { @@ -68,7 +67,6 @@ func (q *Queries) ListAutoStartContainers(ctx context.Context) ([]Container, err &i.Status, &i.Namespace, &i.AutoStart, - &i.HostPath, &i.ContainerPath, &i.CreatedAt, &i.UpdatedAt, @@ -126,7 +124,7 @@ func (q *Queries) UpdateContainerStopped(ctx context.Context, botID pgtype.UUID) const upsertContainer = `-- name: UpsertContainer :exec INSERT INTO containers ( bot_id, container_id, container_name, image, status, namespace, auto_start, - host_path, container_path, last_started_at, last_stopped_at + container_path, last_started_at, last_stopped_at ) VALUES ( $1, @@ -138,8 +136,7 @@ VALUES ( $7, $8, $9, - $10, - $11 + $10 ) ON CONFLICT (container_id) DO UPDATE SET bot_id = EXCLUDED.bot_id, @@ -148,7 +145,6 @@ ON CONFLICT (container_id) DO UPDATE SET status = EXCLUDED.status, namespace = EXCLUDED.namespace, auto_start = EXCLUDED.auto_start, - host_path = EXCLUDED.host_path, container_path = EXCLUDED.container_path, last_started_at = EXCLUDED.last_started_at, last_stopped_at = EXCLUDED.last_stopped_at, @@ -163,7 +159,6 @@ type UpsertContainerParams struct { Status string `json:"status"` Namespace string `json:"namespace"` AutoStart bool `json:"auto_start"` - HostPath pgtype.Text `json:"host_path"` ContainerPath string `json:"container_path"` LastStartedAt pgtype.Timestamptz `json:"last_started_at"` LastStoppedAt pgtype.Timestamptz `json:"last_stopped_at"` @@ -178,7 +173,6 @@ func (q *Queries) UpsertContainer(ctx context.Context, arg UpsertContainerParams arg.Status, arg.Namespace, arg.AutoStart, - arg.HostPath, arg.ContainerPath, arg.LastStartedAt, arg.LastStoppedAt, diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index b35ce7ee..222611ee 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -185,7 +185,6 @@ type Container struct { Status string `json:"status"` Namespace string `json:"namespace"` AutoStart bool `json:"auto_start"` - HostPath pgtype.Text `json:"host_path"` ContainerPath string `json:"container_path"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` diff --git a/internal/fs/service.go b/internal/fs/service.go deleted file mode 100644 index 8e52161c..00000000 --- a/internal/fs/service.go +++ /dev/null @@ -1,520 +0,0 @@ -package fs - -import ( - "bytes" - "context" - "encoding/base64" - "errors" - "fmt" - "io" - "mime" - "net/http" - "os" - "path/filepath" - "strings" - "time" - - "github.com/containerd/containerd/v2/pkg/namespaces" - - "github.com/memohai/memoh/internal/config" - ctr "github.com/memohai/memoh/internal/containerd" - "github.com/memohai/memoh/internal/db" - dbsqlc "github.com/memohai/memoh/internal/db/sqlc" - memoryfmt "github.com/memohai/memoh/internal/memory" - "github.com/memohai/memoh/internal/mcp" -) - -type Error struct { - Code int - Message string - Err error -} - -func (e *Error) Error() string { - if strings.TrimSpace(e.Message) != "" { - return e.Message - } - if e.Err != nil { - return e.Err.Error() - } - return "fs operation failed" -} - -func (e *Error) Unwrap() error { return e.Err } - -func AsError(err error) (*Error, bool) { - var fsErr *Error - if errors.As(err, &fsErr) { - return fsErr, true - } - return nil, false -} - -type FileInfo struct { - Name string `json:"name"` - Path string `json:"path"` - Size int64 `json:"size"` - Mode string `json:"mode"` - ModTime string `json:"modTime"` - IsDir bool `json:"isDir"` -} - -type ListResult struct { - Path string `json:"path"` - Entries []FileInfo `json:"entries"` -} - -type ReadResult struct { - Path string `json:"path"` - Content string `json:"content"` - Size int64 `json:"size"` -} - -type DownloadResult struct { - FileName string - ContentType string - Data []byte - HostPath string - FromHost bool -} - -type UploadResult struct { - Path string `json:"path"` - Size int64 `json:"size"` -} - -type Service struct { - exec ctr.Service - queries *dbsqlc.Queries - namespace string - ensureBotDataRoot func(botID string) (string, error) -} - -func NewService(exec ctr.Service, queries *dbsqlc.Queries, namespace string, ensureBotDataRoot func(botID string) (string, error)) *Service { - return &Service{ - exec: exec, - queries: queries, - namespace: strings.TrimSpace(namespace), - ensureBotDataRoot: ensureBotDataRoot, - } -} - -type pathContext struct { - containerPath string - hostPath string - insideDataMount bool -} - -func (s *Service) Stat(ctx context.Context, botID, rawPath string) (FileInfo, error) { - if strings.TrimSpace(rawPath) == "" { - rawPath = "/" - } - pc, err := s.resolvePath(botID, rawPath) - if err != nil { - return FileInfo{}, err - } - if pc.insideDataMount { - info, osErr := os.Stat(pc.hostPath) - if osErr != nil { - if os.IsNotExist(osErr) { - return FileInfo{}, notFound("not found", osErr) - } - return FileInfo{}, internal(osErr.Error(), osErr) - } - return osFileInfoToFS(pc.containerPath, info), nil - } - out, err := s.execRead(ctx, botID, []string{"stat", "-c", `%n|%s|%a|%Y|%F`, pc.containerPath}) - if err != nil { - return FileInfo{}, internal(err.Error(), err) - } - fi, parseErr := parseStatLine(pc.containerPath, strings.TrimSpace(string(out))) - if parseErr != nil { - return FileInfo{}, internal(parseErr.Error(), parseErr) - } - return fi, nil -} - -func (s *Service) List(ctx context.Context, botID, rawPath string) (ListResult, error) { - if strings.TrimSpace(rawPath) == "" { - rawPath = "/" - } - pc, err := s.resolvePath(botID, rawPath) - if err != nil { - return ListResult{}, err - } - if pc.insideDataMount { - dirEntries, osErr := os.ReadDir(pc.hostPath) - if osErr != nil { - if os.IsNotExist(osErr) { - return ListResult{}, notFound("directory not found", osErr) - } - return ListResult{}, internal(osErr.Error(), osErr) - } - entries := make([]FileInfo, 0, len(dirEntries)) - for _, de := range dirEntries { - info, infoErr := de.Info() - if infoErr != nil { - continue - } - childPath := filepath.Join(pc.containerPath, de.Name()) - entries = append(entries, osFileInfoToFS(childPath, info)) - } - return ListResult{Path: pc.containerPath, Entries: entries}, nil - } - - out, err := s.execRead(ctx, botID, []string{"ls", "-1a", pc.containerPath}) - if err != nil { - return ListResult{}, internal(err.Error(), err) - } - lines := strings.Split(strings.TrimSpace(string(out)), "\n") - entries := make([]FileInfo, 0, len(lines)) - for _, name := range lines { - name = strings.TrimSpace(name) - if name == "" || name == "." || name == ".." { - continue - } - childPath := filepath.Join(pc.containerPath, name) - statOut, statErr := s.execRead(ctx, botID, []string{"stat", "-c", `%n|%s|%a|%Y|%F`, childPath}) - if statErr != nil { - entries = append(entries, FileInfo{Name: name, Path: childPath}) - continue - } - fi, parseErr := parseStatLine(childPath, strings.TrimSpace(string(statOut))) - if parseErr != nil { - entries = append(entries, FileInfo{Name: name, Path: childPath}) - continue - } - entries = append(entries, fi) - } - return ListResult{Path: pc.containerPath, Entries: entries}, nil -} - -func (s *Service) Read(ctx context.Context, botID, rawPath string) (ReadResult, error) { - result, err := s.ReadRaw(ctx, botID, rawPath) - if err != nil { - return ReadResult{}, err - } - result.Content = memoryfmt.RenderMemoryDayForDisplay(result.Path, result.Content) - result.Size = int64(len(result.Content)) - return result, nil -} - -func (s *Service) ReadRaw(ctx context.Context, botID, rawPath string) (ReadResult, error) { - if strings.TrimSpace(rawPath) == "" { - return ReadResult{}, badRequest("path is required", nil) - } - pc, err := s.resolvePath(botID, rawPath) - if err != nil { - return ReadResult{}, err - } - if pc.insideDataMount { - data, osErr := os.ReadFile(pc.hostPath) - if osErr != nil { - if os.IsNotExist(osErr) { - return ReadResult{}, notFound("file not found", osErr) - } - return ReadResult{}, internal(osErr.Error(), osErr) - } - return ReadResult{Path: pc.containerPath, Content: string(data), Size: int64(len(data))}, nil - } - out, err := s.execRead(ctx, botID, []string{"cat", pc.containerPath}) - if err != nil { - return ReadResult{}, internal(err.Error(), err) - } - return ReadResult{Path: pc.containerPath, Content: string(out), Size: int64(len(out))}, nil -} - -func (s *Service) Download(ctx context.Context, botID, rawPath string) (DownloadResult, error) { - if strings.TrimSpace(rawPath) == "" { - return DownloadResult{}, badRequest("path is required", nil) - } - pc, err := s.resolvePath(botID, rawPath) - if err != nil { - return DownloadResult{}, err - } - fileName := filepath.Base(pc.containerPath) - contentType := mime.TypeByExtension(filepath.Ext(fileName)) - if contentType == "" { - contentType = "application/octet-stream" - } - if pc.insideDataMount { - info, osErr := os.Stat(pc.hostPath) - if osErr != nil { - if os.IsNotExist(osErr) { - return DownloadResult{}, notFound("file not found", osErr) - } - return DownloadResult{}, internal(osErr.Error(), osErr) - } - if info.IsDir() { - return DownloadResult{}, badRequest("cannot download a directory", nil) - } - return DownloadResult{ - FileName: fileName, - ContentType: contentType, - HostPath: pc.hostPath, - FromHost: true, - }, nil - } - out, err := s.execRead(ctx, botID, []string{"base64", pc.containerPath}) - if err != nil { - return DownloadResult{}, internal(err.Error(), err) - } - decoded, decErr := base64.StdEncoding.DecodeString(strings.TrimSpace(string(out))) - if decErr != nil { - return DownloadResult{}, internal("failed to decode file content", decErr) - } - return DownloadResult{ - FileName: fileName, - ContentType: contentType, - Data: decoded, - }, nil -} - -func (s *Service) Write(botID, path, content string) error { - if strings.TrimSpace(path) == "" { - return badRequest("path is required", nil) - } - pc, err := s.resolvePath(botID, path) - if err != nil { - return err - } - if !pc.insideDataMount { - return forbidden("write operations are only allowed within the data directory", nil) - } - if err := os.MkdirAll(filepath.Dir(pc.hostPath), 0o755); err != nil { - return internal(err.Error(), err) - } - content = memoryfmt.NormalizeMemoryDayContent(pc.containerPath, content) - if err := os.WriteFile(pc.hostPath, []byte(content), 0o644); err != nil { - return internal(err.Error(), err) - } - return nil -} - -func (s *Service) Upload(botID, destPath string, src io.Reader) (UploadResult, error) { - if strings.TrimSpace(destPath) == "" { - return UploadResult{}, badRequest("path is required", nil) - } - pc, err := s.resolvePath(botID, destPath) - if err != nil { - return UploadResult{}, err - } - if !pc.insideDataMount { - return UploadResult{}, forbidden("upload operations are only allowed within the data directory", nil) - } - if err := os.MkdirAll(filepath.Dir(pc.hostPath), 0o755); err != nil { - return UploadResult{}, internal(err.Error(), err) - } - data, err := io.ReadAll(src) - if err != nil { - return UploadResult{}, internal(err.Error(), err) - } - data = []byte(memoryfmt.NormalizeMemoryDayContent(pc.containerPath, string(data))) - if err := os.WriteFile(pc.hostPath, data, 0o644); err != nil { - return UploadResult{}, internal(err.Error(), err) - } - return UploadResult{Path: pc.containerPath, Size: int64(len(data))}, nil -} - -func (s *Service) Mkdir(botID, path string) error { - if strings.TrimSpace(path) == "" { - return badRequest("path is required", nil) - } - pc, err := s.resolvePath(botID, path) - if err != nil { - return err - } - if !pc.insideDataMount { - return forbidden("mkdir operations are only allowed within the data directory", nil) - } - if err := os.MkdirAll(pc.hostPath, 0o755); err != nil { - return internal(err.Error(), err) - } - return nil -} - -func (s *Service) Delete(botID, path string, recursive bool) error { - if strings.TrimSpace(path) == "" { - return badRequest("path is required", nil) - } - pc, err := s.resolvePath(botID, path) - if err != nil { - return err - } - if !pc.insideDataMount { - return forbidden("delete operations are only allowed within the data directory", nil) - } - if filepath.Clean(pc.containerPath) == filepath.Clean(config.DefaultDataMount) { - return forbidden("cannot delete the data root directory", nil) - } - if _, statErr := os.Stat(pc.hostPath); os.IsNotExist(statErr) { - return notFound("not found", statErr) - } - if recursive { - if err := os.RemoveAll(pc.hostPath); err != nil { - return internal(err.Error(), err) - } - return nil - } - if err := os.Remove(pc.hostPath); err != nil { - return internal(err.Error(), err) - } - return nil -} - -func (s *Service) Rename(botID, oldPath, newPath string) error { - if strings.TrimSpace(oldPath) == "" || strings.TrimSpace(newPath) == "" { - return badRequest("oldPath and newPath are required", nil) - } - oldPC, err := s.resolvePath(botID, oldPath) - if err != nil { - return err - } - newPC, err := s.resolvePath(botID, newPath) - if err != nil { - return err - } - if !oldPC.insideDataMount || !newPC.insideDataMount { - return forbidden("rename operations are only allowed within the data directory", nil) - } - if _, statErr := os.Stat(oldPC.hostPath); os.IsNotExist(statErr) { - return notFound("source not found", statErr) - } - if err := os.MkdirAll(filepath.Dir(newPC.hostPath), 0o755); err != nil { - return internal(err.Error(), err) - } - if err := os.Rename(oldPC.hostPath, newPC.hostPath); err != nil { - return internal(err.Error(), err) - } - return nil -} - -func (s *Service) resolvePath(botID, rawPath string) (pathContext, error) { - containerPath := filepath.Clean("/" + strings.TrimSpace(rawPath)) - if containerPath == "" { - containerPath = "/" - } - dataMount := filepath.Clean(config.DefaultDataMount) - if containerPath == dataMount || strings.HasPrefix(containerPath, dataMount+"/") { - if s.ensureBotDataRoot == nil { - return pathContext{}, internal("bot data root resolver not configured", nil) - } - hostRoot, err := s.ensureBotDataRoot(botID) - if err != nil { - return pathContext{}, internal(err.Error(), err) - } - relPath := strings.TrimPrefix(containerPath, dataMount) - if relPath == "" { - relPath = "/" - } - hostPath := filepath.Clean(filepath.Join(hostRoot, filepath.FromSlash(relPath))) - if !strings.HasPrefix(hostPath, hostRoot) { - return pathContext{}, badRequest("path traversal detected", nil) - } - return pathContext{ - containerPath: containerPath, - hostPath: hostPath, - insideDataMount: true, - }, nil - } - return pathContext{containerPath: containerPath}, nil -} - -func (s *Service) resolveContainerID(ctx context.Context, botID string) string { - if s.queries != nil { - pgBotID, err := db.ParseUUID(botID) - if err == nil { - row, dbErr := s.queries.GetContainerByBotID(s.namespacedCtx(ctx), pgBotID) - if dbErr == nil && strings.TrimSpace(row.ContainerID) != "" { - return row.ContainerID - } - } - } - return mcp.ContainerPrefix + botID -} - -func (s *Service) namespacedCtx(ctx context.Context) context.Context { - if ctx == nil { - ctx = context.Background() - } - if s.namespace != "" { - return namespaces.WithNamespace(ctx, s.namespace) - } - return ctx -} - -func (s *Service) execRead(ctx context.Context, botID string, args []string) ([]byte, error) { - containerID := s.resolveContainerID(ctx, botID) - var stdout bytes.Buffer - var stderr bytes.Buffer - result, err := s.exec.ExecTask(s.namespacedCtx(ctx), containerID, ctr.ExecTaskRequest{ - Args: args, - Stdout: &stdout, - Stderr: &stderr, - }) - if err != nil { - return nil, fmt.Errorf("exec failed: %w", err) - } - if result.ExitCode != 0 { - errMsg := strings.TrimSpace(stderr.String()) - if errMsg == "" { - errMsg = fmt.Sprintf("exit code %d", result.ExitCode) - } - return nil, fmt.Errorf("command failed: %s", errMsg) - } - return stdout.Bytes(), nil -} - -func osFileInfoToFS(containerPath string, info os.FileInfo) FileInfo { - return FileInfo{ - Name: info.Name(), - Path: containerPath, - Size: info.Size(), - Mode: fmt.Sprintf("%04o", info.Mode().Perm()), - ModTime: info.ModTime().UTC().Format(time.RFC3339), - IsDir: info.IsDir(), - } -} - -func parseStatLine(containerPath, line string) (FileInfo, error) { - parts := strings.SplitN(line, "|", 5) - if len(parts) < 5 { - return FileInfo{}, fmt.Errorf("unexpected stat output: %s", line) - } - var size int64 - fmt.Sscanf(parts[1], "%d", &size) - mode := strings.TrimSpace(parts[2]) - var epoch int64 - fmt.Sscanf(parts[3], "%d", &epoch) - modTime := time.Unix(epoch, 0).UTC().Format(time.RFC3339) - fileType := strings.TrimSpace(parts[4]) - isDir := strings.Contains(fileType, "directory") - name := filepath.Base(containerPath) - if containerPath == "/" { - name = "/" - } - return FileInfo{ - Name: name, - Path: containerPath, - Size: size, - Mode: mode, - ModTime: modTime, - IsDir: isDir, - }, nil -} - -func badRequest(msg string, err error) error { - return &Error{Code: http.StatusBadRequest, Message: msg, Err: err} -} - -func forbidden(msg string, err error) error { - return &Error{Code: http.StatusForbidden, Message: msg, Err: err} -} - -func notFound(msg string, err error) error { - return &Error{Code: http.StatusNotFound, Message: msg, Err: err} -} - -func internal(msg string, err error) error { - return &Error{Code: http.StatusInternalServerError, Message: msg, Err: err} -} diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index f6726ace..e1337a3c 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -6,8 +6,6 @@ import ( "fmt" "log/slog" "net/http" - "os" - "path/filepath" "sort" "strings" "sync" @@ -16,7 +14,6 @@ import ( "github.com/containerd/errdefs" "github.com/google/uuid" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" "github.com/labstack/echo/v4" "github.com/memohai/memoh/internal/accounts" @@ -25,7 +22,6 @@ import ( ctr "github.com/memohai/memoh/internal/containerd" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" - fsops "github.com/memohai/memoh/internal/fs" "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/policy" ) @@ -46,7 +42,6 @@ type ContainerdHandler struct { accountService *accounts.Service policyService *policy.Service queries *dbsqlc.Queries - fsService *fsops.Service } type CreateContainerRequest struct { @@ -65,7 +60,6 @@ type GetContainerResponse struct { Image string `json:"image"` Status string `json:"status"` Namespace string `json:"namespace"` - HostPath string `json:"host_path,omitempty"` ContainerPath string `json:"container_path"` TaskRunning bool `json:"task_running"` CreatedAt time.Time `json:"created_at"` @@ -117,7 +111,6 @@ func NewContainerdHandler(log *slog.Logger, service ctr.Service, manager *mcp.Ma policyService: policyService, queries: queries, } - h.fsService = fsops.NewService(service, queries, namespace, h.ensureBotDataRoot) return h } @@ -149,10 +142,6 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { root.POST("/tools", h.HandleMCPTools) } -func (h *ContainerdHandler) FSService() *fsops.Service { - return h.fsService -} - // CreateContainer godoc // @Summary Create and start MCP container for bot // @Tags containerd @@ -181,96 +170,23 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } ctx := c.Request().Context() - dataRoot := strings.TrimSpace(h.cfg.DataRoot) - if dataRoot == "" { - dataRoot = config.DefaultDataRoot - } - dataRoot, err = filepath.Abs(dataRoot) - if err != nil { - h.logger.Warn("filepath.Abs failed", slog.Any("error", err)) - } - dataMount := config.DefaultDataMount - dataDir := filepath.Join(dataRoot, "bots", botID) - if err := os.MkdirAll(dataDir, 0o755); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if err := os.MkdirAll(filepath.Join(dataDir, ".skills"), 0o755); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - resolvPath, err := ctr.ResolveConfSource(dataDir) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - spec := h.buildMCPContainerSpec(dataDir, dataMount, resolvPath) - - _, err = h.service.CreateContainer(ctx, ctr.CreateContainerRequest{ - ID: containerID, - ImageRef: image, - Snapshotter: snapshotter, - Labels: map[string]string{ - mcp.BotLabelKey: botID, - }, - Spec: spec, - }) - if err != nil && !errdefs.IsAlreadyExists(err) { - return echo.NewHTTPError(http.StatusInternalServerError, "snapshotter="+snapshotter+" image="+image+" err="+err.Error()) - } - - if h.queries != nil { - pgBotID, parseErr := db.ParseUUID(botID) - if parseErr == nil { - ns := strings.TrimSpace(h.namespace) - if ns == "" { - ns = "default" - } - if dbErr := h.queries.UpsertContainer(c.Request().Context(), dbsqlc.UpsertContainerParams{ - BotID: pgBotID, - ContainerID: containerID, - ContainerName: containerID, - Image: image, - Status: "created", - Namespace: ns, - AutoStart: true, - HostPath: pgtype.Text{String: dataDir, Valid: true}, - ContainerPath: dataMount, - }); dbErr != nil { - h.logger.Error("failed to upsert container record", - slog.String("bot_id", botID), slog.Any("error", dbErr)) - } - } + if h.manager == nil { + return echo.NewHTTPError(http.StatusInternalServerError, "manager not configured") } started := false - if err := h.service.StartContainer(ctx, containerID, &ctr.StartTaskOptions{ - UseStdio: false, - }); err == nil { - started = true - if netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ - ContainerID: containerID, - CNIBinDir: h.cfg.CNIBinaryDir, - CNIConfDir: h.cfg.CNIConfigDir, - }); netErr != nil { - h.logger.Warn("mcp container network setup failed, task kept running", - slog.String("container_id", containerID), - slog.Any("error", netErr), - ) - } - if h.queries != nil { - if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { - if dbErr := h.queries.UpdateContainerStarted(c.Request().Context(), pgBotID); dbErr != nil { - h.logger.Error("failed to update container started status", - slog.String("bot_id", botID), slog.Any("error", dbErr)) - } - } - } - } else { + if err := h.manager.Start(ctx, botID); err != nil { h.logger.Error("mcp container start failed", slog.String("container_id", containerID), slog.Any("error", err), ) + } else { + started = true } + h.upsertContainerRecord(ctx, botID, containerID, map[bool]string{true: "running", false: "created"}[started]) + return c.JSON(http.StatusOK, CreateContainerResponse{ ContainerID: containerID, Image: image, @@ -303,33 +219,38 @@ func (h *ContainerdHandler) ensureContainerAndTask(ctx context.Context, containe } if len(tasks) > 0 { if tasks[0].Status == ctr.TaskStatusRunning { - if netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ + if netResult, netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ ContainerID: containerID, CNIBinDir: h.cfg.CNIBinaryDir, CNIConfDir: h.cfg.CNIConfigDir, }); netErr != nil { h.logger.Warn("network re-setup failed for running task", slog.String("container_id", containerID), slog.Any("error", netErr)) + } else if netResult.IP != "" && h.manager != nil { + h.manager.SetContainerIP(botID, netResult.IP) } return nil } if err := h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}); err != nil { - h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err)) + if !errdefs.IsNotFound(err) { + h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err)) + return err + } } } - if err := h.service.StartContainer(ctx, containerID, &ctr.StartTaskOptions{ - UseStdio: false, - }); err != nil { + if err := h.service.StartContainer(ctx, containerID, nil); err != nil { return err } - if netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ + if netResult, netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ ContainerID: containerID, CNIBinDir: h.cfg.CNIBinaryDir, CNIConfDir: h.cfg.CNIConfigDir, }); netErr != nil { h.logger.Warn("network setup failed, task kept running", slog.String("container_id", containerID), slog.Any("error", netErr)) + } else if netResult.IP != "" && h.manager != nil { + h.manager.SetContainerIP(botID, netResult.IP) } return nil } @@ -388,10 +309,6 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error { row, dbErr := h.queries.GetContainerByBotID(ctx, pgBotID) if dbErr == nil { taskRunning := h.isTaskRunning(ctx, row.ContainerID) - hostPath := "" - if row.HostPath.Valid { - hostPath = row.HostPath.String - } createdAt := time.Time{} if row.CreatedAt.Valid { createdAt = row.CreatedAt.Time @@ -405,7 +322,6 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error { Image: row.Image, Status: row.Status, Namespace: row.Namespace, - HostPath: hostPath, ContainerPath: row.ContainerPath, TaskRunning: taskRunning, CreatedAt: createdAt, @@ -772,99 +688,20 @@ func (h *ContainerdHandler) requireBotAccessWithGuest(c echo.Context) (string, e func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) error { containerID := mcp.ContainerPrefix + botID - image := h.mcpImageRef() - snapshotter := strings.TrimSpace(h.cfg.Snapshotter) - - dataRoot := strings.TrimSpace(h.cfg.DataRoot) - if dataRoot == "" { - dataRoot = config.DefaultDataRoot - } - if absRoot, absErr := filepath.Abs(dataRoot); absErr != nil { - h.logger.Warn("filepath.Abs failed", slog.Any("error", absErr)) - } else { - dataRoot = absRoot - } - dataMount := config.DefaultDataMount - dataDir := filepath.Join(dataRoot, "bots", botID) - if err := os.MkdirAll(dataDir, 0o755); err != nil { - return err - } - if err := os.MkdirAll(filepath.Join(dataDir, ".skills"), 0o755); err != nil { - return err - } - resolvPath, err := ctr.ResolveConfSource(dataDir) - if err != nil { - return err + if h.manager == nil { + return fmt.Errorf("manager not configured") } - spec := h.buildMCPContainerSpec(dataDir, dataMount, resolvPath) - - _, err = h.service.CreateContainer(ctx, ctr.CreateContainerRequest{ - ID: containerID, - ImageRef: image, - Snapshotter: snapshotter, - Labels: map[string]string{ - mcp.BotLabelKey: botID, - }, - Spec: spec, - }) - if err != nil && !errdefs.IsAlreadyExists(err) { - return err - } - - if h.queries != nil { - pgBotID, parseErr := db.ParseUUID(botID) - if parseErr == nil { - ns := strings.TrimSpace(h.namespace) - if ns == "" { - ns = "default" - } - if dbErr := h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{ - BotID: pgBotID, - ContainerID: containerID, - ContainerName: containerID, - Image: image, - Status: "created", - Namespace: ns, - AutoStart: true, - HostPath: pgtype.Text{String: dataDir, Valid: true}, - ContainerPath: dataMount, - }); dbErr != nil { - h.logger.Error("setup bot container: failed to upsert container record", - slog.String("bot_id", botID), slog.Any("error", dbErr)) - } - } - } - - if err := h.service.StartContainer(ctx, containerID, &ctr.StartTaskOptions{ - UseStdio: false, - }); err == nil { - if netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ - ContainerID: containerID, - CNIBinDir: h.cfg.CNIBinaryDir, - CNIConfDir: h.cfg.CNIConfigDir, - }); netErr != nil { - h.logger.Warn("setup bot container: network setup failed, task kept running", - slog.String("bot_id", botID), - slog.String("container_id", containerID), - slog.Any("error", netErr), - ) - } - if h.queries != nil { - if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil { - if dbErr := h.queries.UpdateContainerStarted(ctx, pgBotID); dbErr != nil { - h.logger.Error("setup bot container: failed to update container started status", - slog.String("bot_id", botID), slog.Any("error", dbErr)) - } - } - } - } else { - h.logger.Error("setup bot container: task start failed", + if err := h.manager.Start(ctx, botID); err != nil { + h.logger.Error("setup bot container: start failed", slog.String("bot_id", botID), slog.String("container_id", containerID), slog.Any("error", err), ) + return err } + + h.upsertContainerRecord(ctx, botID, containerID, "running") return nil } @@ -1000,7 +837,7 @@ func (h *ContainerdHandler) ReconcileContainers(ctx context.Context) { slog.String("bot_id", botID), slog.Any("error", dbErr)) } } - if netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ + if netResult, netErr := h.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ ContainerID: containerID, CNIBinDir: h.cfg.CNIBinaryDir, CNIConfDir: h.cfg.CNIConfigDir, @@ -1009,6 +846,8 @@ func (h *ContainerdHandler) ReconcileContainers(ctx context.Context) { slog.String("bot_id", botID), slog.String("container_id", containerID), slog.Any("error", netErr)) + } else if netResult.IP != "" && h.manager != nil { + h.manager.SetContainerIP(botID, netResult.IP) } h.logger.Info("reconcile: container healthy", slog.String("bot_id", botID), slog.String("container_id", containerID)) @@ -1035,50 +874,34 @@ func (h *ContainerdHandler) ReconcileContainers(ctx context.Context) { h.logger.Info("reconcile: completed") } -func (h *ContainerdHandler) buildMCPContainerSpec(dataDir, dataMount, resolvPath string) ctr.ContainerSpec { - mounts := []ctr.MountSpec{ - { - Destination: dataMount, - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, - { - Destination: "/etc/resolv.conf", - Type: "bind", - Source: resolvPath, - Options: []string{"rbind", "ro"}, - }, +func (h *ContainerdHandler) upsertContainerRecord(ctx context.Context, botID, containerID, status string) { + if h.queries == nil { + return } - tzMounts, tzEnv := ctr.TimezoneSpec() - mounts = append(mounts, tzMounts...) - - bootScript := fmt.Sprintf( - "[ -e /app/mcp ] || { mkdir -p /app; cp -a /opt/mcp /app/mcp 2>/dev/null; true; }; "+ - "cp -an /opt/mcp-template/* %s/ 2>/dev/null; true; "+ - "exec /app/mcp", - dataMount, - ) - - return ctr.ContainerSpec{ - Cmd: []string{"/bin/sh", "-c", bootScript}, - Env: tzEnv, - Mounts: mounts, - } -} - -func (h *ContainerdHandler) ensureBotDataRoot(botID string) (string, error) { - dataRoot := strings.TrimSpace(h.cfg.DataRoot) - if dataRoot == "" { - dataRoot = config.DefaultDataRoot - } - dataRoot, err := filepath.Abs(dataRoot) + pgBotID, err := db.ParseUUID(botID) if err != nil { - return "", err + return } - root := filepath.Join(dataRoot, "bots", botID) - if err := os.MkdirAll(root, 0o755); err != nil { - return "", err + ns := strings.TrimSpace(h.namespace) + if ns == "" { + ns = "default" + } + if dbErr := h.queries.UpsertContainer(ctx, dbsqlc.UpsertContainerParams{ + BotID: pgBotID, + ContainerID: containerID, + ContainerName: containerID, + Image: h.mcpImageRef(), + Status: status, + Namespace: ns, + AutoStart: true, + }); dbErr != nil { + h.logger.Error("failed to upsert container record", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } + if status == "running" { + if dbErr := h.queries.UpdateContainerStarted(ctx, pgBotID); dbErr != nil { + h.logger.Error("failed to update container started status", + slog.String("bot_id", botID), slog.Any("error", dbErr)) + } } - return root, nil } diff --git a/internal/handlers/filemanager.go b/internal/handlers/filemanager.go index a0f0ccc6..87acf5a3 100644 --- a/internal/handlers/filemanager.go +++ b/internal/handlers/filemanager.go @@ -1,21 +1,46 @@ package handlers import ( + "context" + "errors" "fmt" + "io" + "mime" "net/http" + "path/filepath" "strings" "github.com/labstack/echo/v4" - fsops "github.com/memohai/memoh/internal/fs" + "github.com/memohai/memoh/internal/mcp/mcpclient" ) // ---------- request / response types ---------- -type FSFileInfo = fsops.FileInfo -type FSListResponse = fsops.ListResult -type FSReadResponse = fsops.ReadResult -type FSUploadResponse = fsops.UploadResult +type FSFileInfo struct { + Name string `json:"name"` + Path string `json:"path"` + Size int64 `json:"size"` + Mode string `json:"mode"` + ModTime string `json:"modTime"` + IsDir bool `json:"isDir"` +} + +type FSListResponse struct { + Path string `json:"path"` + Entries []FSFileInfo `json:"entries"` +} + +type FSReadResponse struct { + Path string `json:"path"` + Content string `json:"content"` + Size int64 `json:"size"` +} + +type FSUploadResponse struct { + Path string `json:"path"` + Size int64 `json:"size"` +} // FSWriteRequest is the body for creating / overwriting a file. type FSWriteRequest struct { @@ -44,6 +69,53 @@ type fsOpResponse struct { OK bool `json:"ok"` } +// ---------- helpers ---------- + +// resolveContainerPath cleans and validates a container-relative path. +func resolveContainerPath(rawPath string) (string, error) { + cleaned := filepath.Clean("/" + strings.TrimSpace(rawPath)) + if cleaned == "" { + cleaned = "/" + } + if strings.HasPrefix(cleaned, "..") { + return "", fmt.Errorf("invalid path") + } + return cleaned, nil +} + +// getGRPCClient returns the gRPC client for the bot's container. +func (h *ContainerdHandler) getGRPCClient(ctx context.Context, botID string) (*mcpclient.Client, error) { + return h.manager.MCPClient(ctx, botID) +} + +// fsFileInfoFromEntry converts a gRPC FileEntry to FSFileInfo. +func fsFileInfoFromEntry(containerPath, name string, isDir bool, size int64, mode, modTime string) FSFileInfo { + return FSFileInfo{ + Name: name, + Path: filepath.Join(containerPath, name), + Size: size, + Mode: mode, + ModTime: modTime, + IsDir: isDir, + } +} + +// fsHTTPError maps mcpclient domain errors to HTTP status codes. +func fsHTTPError(err error) *echo.HTTPError { + switch { + case errors.Is(err, mcpclient.ErrNotFound): + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + case errors.Is(err, mcpclient.ErrBadRequest): + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + case errors.Is(err, mcpclient.ErrForbidden): + return echo.NewHTTPError(http.StatusForbidden, err.Error()) + case errors.Is(err, mcpclient.ErrUnavailable): + return echo.NewHTTPError(http.StatusServiceUnavailable, err.Error()) + default: + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } +} + // ---------- handlers ---------- // FSStat godoc @@ -63,11 +135,35 @@ func (h *ContainerdHandler) FSStat(c echo.Context) error { if err != nil { return err } - fi, err := h.fsService.Stat(c.Request().Context(), botID, c.QueryParam("path")) - if err != nil { - return h.toFSHTTPError(err) + rawPath := c.QueryParam("path") + if strings.TrimSpace(rawPath) == "" { + rawPath = "/" } - return c.JSON(http.StatusOK, fi) + + containerPath, err := resolveContainerPath(rawPath) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + + entry, err := client.Stat(ctx, containerPath) + if err != nil { + return fsHTTPError(err) + } + + return c.JSON(http.StatusOK, FSFileInfo{ + Name: filepath.Base(containerPath), + Path: containerPath, + Size: entry.GetSize(), + Mode: entry.GetMode(), + ModTime: entry.GetModTime(), + IsDir: entry.GetIsDir(), + }) } // FSList godoc @@ -86,11 +182,46 @@ func (h *ContainerdHandler) FSList(c echo.Context) error { if err != nil { return err } - resp, err := h.fsService.List(c.Request().Context(), botID, c.QueryParam("path")) - if err != nil { - return h.toFSHTTPError(err) + rawPath := c.QueryParam("path") + if strings.TrimSpace(rawPath) == "" { + rawPath = "/" } - return c.JSON(http.StatusOK, resp) + + containerPath, err := resolveContainerPath(rawPath) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + + entries, err := client.ListDir(ctx, containerPath, false) + if err != nil { + return fsHTTPError(err) + } + + fileInfos := make([]FSFileInfo, 0, len(entries)) + for _, e := range entries { + if e.Path == containerPath { + continue + } + fileInfos = append(fileInfos, fsFileInfoFromEntry( + containerPath, + filepath.Base(e.Path), + e.IsDir, + e.Size, + e.Mode, + e.ModTime, + )) + } + + return c.JSON(http.StatusOK, FSListResponse{ + Path: containerPath, + Entries: fileInfos, + }) } // FSRead godoc @@ -109,11 +240,32 @@ func (h *ContainerdHandler) FSRead(c echo.Context) error { if err != nil { return err } - resp, err := h.fsService.Read(c.Request().Context(), botID, c.QueryParam("path")) - if err != nil { - return h.toFSHTTPError(err) + rawPath := c.QueryParam("path") + if strings.TrimSpace(rawPath) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "path is required") } - return c.JSON(http.StatusOK, resp) + + containerPath, err := resolveContainerPath(rawPath) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + + resp, err := client.ReadFile(ctx, containerPath, 0, 0) + if err != nil { + return fsHTTPError(err) + } + + return c.JSON(http.StatusOK, FSReadResponse{ + Path: containerPath, + Content: resp.GetContent(), + Size: int64(len(resp.GetContent())), + }) } // FSDownload godoc @@ -133,15 +285,41 @@ func (h *ContainerdHandler) FSDownload(c echo.Context) error { if err != nil { return err } - resp, err := h.fsService.Download(c.Request().Context(), botID, c.QueryParam("path")) + rawPath := c.QueryParam("path") + if strings.TrimSpace(rawPath) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "path is required") + } + + containerPath, err := resolveContainerPath(rawPath) if err != nil { - return h.toFSHTTPError(err) + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, resp.FileName)) - if resp.FromHost { - return c.File(resp.HostPath) + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) } - return c.Blob(http.StatusOK, resp.ContentType, resp.Data) + + rc, err := client.ReadRaw(ctx, containerPath) + if err != nil { + return fsHTTPError(err) + } + defer rc.Close() + + data, err := io.ReadAll(rc) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "failed to read file") + } + + fileName := filepath.Base(containerPath) + contentType := mime.TypeByExtension(filepath.Ext(fileName)) + if contentType == "" { + contentType = "application/octet-stream" + } + + c.Response().Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, fileName)) + return c.Blob(http.StatusOK, contentType, data) } // FSWrite godoc @@ -164,9 +342,25 @@ func (h *ContainerdHandler) FSWrite(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.fsService.Write(botID, req.Path, req.Content); err != nil { - return h.toFSHTTPError(err) + if strings.TrimSpace(req.Path) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "path is required") } + + containerPath, err := resolveContainerPath(req.Path) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + + if err := client.WriteFile(ctx, containerPath, []byte(req.Content)); err != nil { + return fsHTTPError(err) + } + return c.JSON(http.StatusOK, fsOpResponse{OK: true}) } @@ -192,6 +386,18 @@ func (h *ContainerdHandler) FSUpload(c echo.Context) error { if destPath == "" { return echo.NewHTTPError(http.StatusBadRequest, "path is required") } + + containerPath, err := resolveContainerPath(destPath) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + file, err := c.FormFile("file") if err != nil { return echo.NewHTTPError(http.StatusBadRequest, "file is required") @@ -201,11 +407,16 @@ func (h *ContainerdHandler) FSUpload(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } defer src.Close() - resp, err := h.fsService.Upload(botID, destPath, src) + + written, err := client.WriteRaw(ctx, containerPath, src) if err != nil { - return h.toFSHTTPError(err) + return fsHTTPError(err) } - return c.JSON(http.StatusOK, resp) + + return c.JSON(http.StatusOK, FSUploadResponse{ + Path: containerPath, + Size: written, + }) } // FSMkdir godoc @@ -228,9 +439,25 @@ func (h *ContainerdHandler) FSMkdir(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.fsService.Mkdir(botID, req.Path); err != nil { - return h.toFSHTTPError(err) + if strings.TrimSpace(req.Path) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "path is required") } + + containerPath, err := resolveContainerPath(req.Path) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + + if err := client.Mkdir(ctx, containerPath); err != nil { + return fsHTTPError(err) + } + return c.JSON(http.StatusOK, fsOpResponse{OK: true}) } @@ -255,9 +482,29 @@ func (h *ContainerdHandler) FSDelete(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.fsService.Delete(botID, req.Path, req.Recursive); err != nil { - return h.toFSHTTPError(err) + if strings.TrimSpace(req.Path) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "path is required") } + + containerPath, err := resolveContainerPath(req.Path) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + if containerPath == "/" { + return echo.NewHTTPError(http.StatusForbidden, "cannot delete root directory") + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + + if err := client.DeleteFile(ctx, containerPath, req.Recursive); err != nil { + return fsHTTPError(err) + } + return c.JSON(http.StatusOK, fsOpResponse{OK: true}) } @@ -282,15 +529,28 @@ func (h *ContainerdHandler) FSRename(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if err := h.fsService.Rename(botID, req.OldPath, req.NewPath); err != nil { - return h.toFSHTTPError(err) + if strings.TrimSpace(req.OldPath) == "" || strings.TrimSpace(req.NewPath) == "" { + return echo.NewHTTPError(http.StatusBadRequest, "oldPath and newPath are required") } + + oldPath, err := resolveContainerPath(req.OldPath) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + newPath, err := resolveContainerPath(req.NewPath) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusServiceUnavailable, fmt.Sprintf("container not reachable: %v", err)) + } + + if err := client.Rename(ctx, oldPath, newPath); err != nil { + return fsHTTPError(err) + } + return c.JSON(http.StatusOK, fsOpResponse{OK: true}) } - -func (h *ContainerdHandler) toFSHTTPError(err error) error { - if fsErr, ok := fsops.AsError(err); ok { - return echo.NewHTTPError(fsErr.Code, fsErr.Message) - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) -} diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go deleted file mode 100644 index b7cff5db..00000000 --- a/internal/handlers/fs.go +++ /dev/null @@ -1,641 +0,0 @@ -package handlers - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "os/exec" - "path/filepath" - "strings" - "sync" - - "github.com/containerd/errdefs" - "github.com/labstack/echo/v4" - sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" - sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" - - ctr "github.com/memohai/memoh/internal/containerd" - mcptools "github.com/memohai/memoh/internal/mcp" -) - -func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerID, botID string) error { - if strings.TrimSpace(botID) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "bot id is required") - } - info, err := h.service.GetContainer(ctx, containerID) - if err != nil { - if errdefs.IsNotFound(err) { - return echo.NewHTTPError(http.StatusNotFound, "container not found") - } - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - labelBotID := strings.TrimSpace(info.Labels[mcptools.BotLabelKey]) - if labelBotID != "" && labelBotID != botID { - return echo.NewHTTPError(http.StatusForbidden, "bot mismatch") - } - return nil -} - -type mcpSession struct { - stdin io.WriteCloser - stdout io.ReadCloser - stderr io.ReadCloser - cmd *exec.Cmd - initMu sync.Mutex - initState mcpSessionInitState - initWait chan struct{} - pendingMu sync.Mutex - pending map[string]chan *sdkjsonrpc.Response - conn sdkmcp.Connection - closed chan struct{} - closeOnce sync.Once - closeErr error - onClose func() -} - -type mcpSessionInitState uint8 - -const ( - mcpSessionInitStateNone mcpSessionInitState = iota - mcpSessionInitStateInitializing - mcpSessionInitStateInitialized - mcpSessionInitStateReady -) - -func (h *ContainerdHandler) getMCPSession(ctx context.Context, containerID string) (*mcpSession, error) { - h.mcpMu.Lock() - if sess, ok := h.mcpSess[containerID]; ok { - h.mcpMu.Unlock() - return sess, nil - } - h.mcpMu.Unlock() - - sess, err := h.startContainerdMCPSession(ctx, containerID) - if err != nil { - return nil, err - } - - h.mcpMu.Lock() - h.mcpSess[containerID] = sess - h.mcpMu.Unlock() - - sess.onClose = func() { - h.mcpMu.Lock() - if current, ok := h.mcpSess[containerID]; ok && current == sess { - delete(h.mcpSess, containerID) - } - h.mcpMu.Unlock() - } - - return sess, nil -} - -func (h *ContainerdHandler) startContainerdMCPSession(ctx context.Context, containerID string) (*mcpSession, error) { - execSession, err := h.service.ExecTaskStreaming(ctx, containerID, ctr.ExecTaskRequest{ - Args: []string{"/app/mcp"}, - FIFODir: h.mcpFIFODir(), - }) - if err != nil { - return nil, err - } - - sess := &mcpSession{ - stdin: execSession.Stdin, - stdout: execSession.Stdout, - stderr: execSession.Stderr, - pending: make(map[string]chan *sdkjsonrpc.Response), - closed: make(chan struct{}), - } - transport := &sdkmcp.IOTransport{ - Reader: sess.stdout, - Writer: sess.stdin, - } - conn, err := transport.Connect(ctx) - if err != nil { - sess.closeWithError(err) - return nil, err - } - sess.conn = conn - - h.startMCPStderrLogger(execSession.Stderr, containerID) - go sess.readLoop() - go func() { - _, err := execSession.Wait() - if err != nil { - if isBenignMCPSessionExit(err) { - sess.closeWithError(io.EOF) - return - } - h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID)) - sess.closeWithError(err) - return - } - sess.closeWithError(io.EOF) - }() - - return sess, nil -} - -func (s *mcpSession) closeWithError(err error) { - s.closeOnce.Do(func() { - s.closeErr = err - close(s.closed) - s.pendingMu.Lock() - for _, ch := range s.pending { - close(ch) - } - s.pending = map[string]chan *sdkjsonrpc.Response{} - s.pendingMu.Unlock() - if s.conn != nil { - _ = s.conn.Close() - } - if s.stdin != nil { - _ = s.stdin.Close() - } - if s.stdout != nil { - _ = s.stdout.Close() - } - if s.stderr != nil { - _ = s.stderr.Close() - } - if s.cmd != nil && s.cmd.Process != nil { - _ = s.cmd.Process.Kill() - } - if s.onClose != nil { - s.onClose() - } - }) -} - -func (h *ContainerdHandler) startMCPStderrLogger(stderr io.ReadCloser, containerID string) { - if stderr == nil { - return - } - go func() { - scanner := bufio.NewScanner(stderr) - scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - h.logger.Warn("mcp stderr", slog.String("container_id", containerID), slog.String("message", line)) - } - if err := scanner.Err(); err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "closed pipe") { - return - } - h.logger.Error("mcp stderr read failed", slog.Any("error", err), slog.String("container_id", containerID)) - } - }() -} - -func isBenignMCPSessionExit(err error) bool { - if err == nil { - return false - } - if errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { - return true - } - msg := strings.ToLower(err.Error()) - return strings.Contains(msg, "code = canceled") || strings.Contains(msg, "context canceled") || strings.Contains(msg, "closed pipe") -} - -func (h *ContainerdHandler) mcpFIFODir() string { - if root := strings.TrimSpace(h.cfg.DataRoot); root != "" { - return filepath.Join(root, ".containerd-fifo") - } - return "/tmp/memoh-containerd-fifo" -} - -func (s *mcpSession) readLoop() { - if s.conn == nil { - s.closeWithError(io.EOF) - return - } - for { - msg, err := s.conn.Read(context.Background()) - if err != nil { - if errors.Is(err, io.EOF) { - s.closeWithError(io.EOF) - return - } - s.closeWithError(err) - return - } - resp, ok := msg.(*sdkjsonrpc.Response) - if !ok || !resp.ID.IsValid() { - continue - } - id := sdkIDKey(resp.ID) - if id == "" { - continue - } - s.pendingMu.Lock() - ch, ok := s.pending[id] - if ok { - delete(s.pending, id) - } - s.pendingMu.Unlock() - if ok { - ch <- resp - close(ch) - } - } -} - -func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { - method := strings.TrimSpace(req.Method) - if method == "initialize" { - return s.callInitialize(ctx, req) - } - if method != "notifications/initialized" { - if err := s.ensureInitialized(ctx); err != nil { - return nil, err - } - } - - targetID, err := parseRawJSONRPCID(req.ID) - if err != nil { - return nil, err - } - target := sdkIDKey(targetID) - if target == "" { - return nil, fmt.Errorf("missing request id") - } - if s.conn == nil { - return nil, io.EOF - } - - respCh := make(chan *sdkjsonrpc.Response, 1) - s.pendingMu.Lock() - s.pending[target] = respCh - s.pendingMu.Unlock() - - callReq := &sdkjsonrpc.Request{ - ID: targetID, - Method: method, - Params: req.Params, - } - if err := s.conn.Write(ctx, callReq); err != nil { - s.removePending(target) - return nil, err - } - - select { - case resp, ok := <-respCh: - if !ok { - if s.closeErr != nil { - return nil, s.closeErr - } - return nil, io.EOF - } - if method == "notifications/initialized" { - s.setInitStateAtLeast(mcpSessionInitStateReady) - } - return sdkResponsePayload(resp) - case <-s.closed: - if s.closeErr != nil { - return nil, s.closeErr - } - return nil, io.EOF - case <-ctx.Done(): - s.removePending(target) - return nil, ctx.Err() - } -} - -func (s *mcpSession) callInitialize(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { - payload, err := s.callRaw(ctx, req) - if err != nil { - return nil, err - } - if err := mcptools.PayloadError(payload); err != nil { - return payload, nil - } - s.setInitStateAtLeast(mcpSessionInitStateInitialized) - return payload, nil -} - -func (s *mcpSession) callRaw(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { - method := strings.TrimSpace(req.Method) - targetID, err := parseRawJSONRPCID(req.ID) - if err != nil { - return nil, err - } - target := sdkIDKey(targetID) - if target == "" { - return nil, fmt.Errorf("missing request id") - } - if s.conn == nil { - return nil, io.EOF - } - - respCh := make(chan *sdkjsonrpc.Response, 1) - s.pendingMu.Lock() - s.pending[target] = respCh - s.pendingMu.Unlock() - - callReq := &sdkjsonrpc.Request{ - ID: targetID, - Method: method, - Params: req.Params, - } - if err := s.conn.Write(ctx, callReq); err != nil { - s.removePending(target) - return nil, err - } - - select { - case resp, ok := <-respCh: - if !ok { - if s.closeErr != nil { - return nil, s.closeErr - } - return nil, io.EOF - } - return sdkResponsePayload(resp) - case <-s.closed: - if s.closeErr != nil { - return nil, s.closeErr - } - return nil, io.EOF - case <-ctx.Done(): - s.removePending(target) - return nil, ctx.Err() - } -} - -func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error { - if s.conn == nil { - return io.EOF - } - method := strings.TrimSpace(req.Method) - notification := &sdkjsonrpc.Request{ - Method: method, - Params: req.Params, - } - if err := s.conn.Write(ctx, notification); err != nil { - return err - } - if method == "notifications/initialized" { - s.setInitStateAtLeast(mcpSessionInitStateReady) - } - return nil -} - -func (s *mcpSession) ensureInitialized(ctx context.Context) error { - for { - s.initMu.Lock() - switch s.initState { - case mcpSessionInitStateReady: - s.initMu.Unlock() - return nil - case mcpSessionInitStateInitializing: - waitCh := s.initWait - s.initMu.Unlock() - if waitCh == nil { - continue - } - select { - case <-waitCh: - continue - case <-ctx.Done(): - return ctx.Err() - case <-s.closed: - if s.closeErr != nil { - return s.closeErr - } - return io.EOF - } - case mcpSessionInitStateInitialized: - waitCh := make(chan struct{}) - s.initState = mcpSessionInitStateInitializing - s.initWait = waitCh - s.initMu.Unlock() - - err := s.sendInitializedNotification(ctx) - - s.initMu.Lock() - if err == nil { - s.initState = mcpSessionInitStateReady - } else { - s.initState = mcpSessionInitStateInitialized - } - s.initWait = nil - close(waitCh) - s.initMu.Unlock() - - if err != nil { - return err - } - return nil - default: - waitCh := make(chan struct{}) - s.initState = mcpSessionInitStateInitializing - s.initWait = waitCh - s.initMu.Unlock() - - nextState, err := s.initializeHandshake(ctx) - - s.initMu.Lock() - s.initState = nextState - s.initWait = nil - close(waitCh) - s.initMu.Unlock() - - if err != nil { - return err - } - if nextState == mcpSessionInitStateReady { - return nil - } - } - } -} - -func (s *mcpSession) initializeHandshake(ctx context.Context) (mcpSessionInitState, error) { - params, err := json.Marshal(map[string]any{ - "protocolVersion": "2025-06-18", - "capabilities": map[string]any{ - "roots": map[string]any{ - "listChanged": false, - }, - }, - "clientInfo": map[string]any{ - "name": "memoh-http-proxy", - "version": "v0", - }, - }) - if err != nil { - return mcpSessionInitStateNone, err - } - initID, err := sdkjsonrpc.MakeID("init-1") - if err != nil { - return mcpSessionInitStateNone, err - } - initResp, err := s.invokeCall(ctx, &sdkjsonrpc.Request{ - ID: initID, - Method: "initialize", - Params: params, - }) - if err != nil { - return mcpSessionInitStateNone, err - } - if initResp.Error != nil { - return mcpSessionInitStateNone, initResp.Error - } - if err := s.sendInitializedNotification(ctx); err != nil { - return mcpSessionInitStateInitialized, err - } - return mcpSessionInitStateReady, nil -} - -func (s *mcpSession) sendInitializedNotification(ctx context.Context) error { - if s.conn == nil { - return io.EOF - } - return s.conn.Write(ctx, &sdkjsonrpc.Request{ - Method: "notifications/initialized", - }) -} - -func (s *mcpSession) invokeCall(ctx context.Context, req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { - if s.conn == nil { - return nil, io.EOF - } - if req == nil || !req.ID.IsValid() { - return nil, fmt.Errorf("missing request id") - } - key := sdkIDKey(req.ID) - if key == "" { - return nil, fmt.Errorf("invalid request id") - } - - respCh := make(chan *sdkjsonrpc.Response, 1) - s.pendingMu.Lock() - s.pending[key] = respCh - s.pendingMu.Unlock() - - if err := s.conn.Write(ctx, req); err != nil { - s.removePending(key) - return nil, err - } - - select { - case resp, ok := <-respCh: - if !ok { - if s.closeErr != nil { - return nil, s.closeErr - } - return nil, io.EOF - } - return resp, nil - case <-s.closed: - if s.closeErr != nil { - return nil, s.closeErr - } - return nil, io.EOF - case <-ctx.Done(): - s.removePending(key) - return nil, ctx.Err() - } -} - -func (s *mcpSession) removePending(key string) { - if strings.TrimSpace(key) == "" { - return - } - s.pendingMu.Lock() - delete(s.pending, key) - s.pendingMu.Unlock() -} - -func (s *mcpSession) setInitStateAtLeast(next mcpSessionInitState) { - s.initMu.Lock() - if s.initState != mcpSessionInitStateInitializing && s.initState < next { - s.initState = next - } - s.initMu.Unlock() -} - -func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { - if len(raw) == 0 { - return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") - } - var idValue any - if err := json.Unmarshal(raw, &idValue); err != nil { - return sdkjsonrpc.ID{}, err - } - id, err := sdkjsonrpc.MakeID(idValue) - if err != nil { - return sdkjsonrpc.ID{}, err - } - if !id.IsValid() { - return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") - } - return id, nil -} - -func sdkIDKey(id sdkjsonrpc.ID) string { - if !id.IsValid() { - return "" - } - raw, err := json.Marshal(id.Raw()) - if err != nil { - return "" - } - return string(raw) -} - -func sdkIDRaw(id sdkjsonrpc.ID) json.RawMessage { - if !id.IsValid() { - return nil - } - raw, err := json.Marshal(id.Raw()) - if err != nil { - return nil - } - return json.RawMessage(raw) -} - -func sdkResponsePayload(resp *sdkjsonrpc.Response) (map[string]any, error) { - if resp == nil { - return nil, io.EOF - } - if resp.Error != nil { - code := int64(-32603) - message := strings.TrimSpace(resp.Error.Error()) - if wireErr, ok := resp.Error.(*sdkjsonrpc.Error); ok { - code = wireErr.Code - message = strings.TrimSpace(wireErr.Message) - } - if message == "" { - message = "internal error" - } - return map[string]any{ - "jsonrpc": "2.0", - "id": sdkIDRaw(resp.ID), - "error": map[string]any{ - "code": code, - "message": message, - }, - }, nil - } - var result any - if len(resp.Result) > 0 { - if err := json.Unmarshal(resp.Result, &result); err != nil { - return nil, err - } - } - return map[string]any{ - "jsonrpc": "2.0", - "id": sdkIDRaw(resp.ID), - "result": result, - }, nil -} diff --git a/internal/handlers/fs_mcp_session_test.go b/internal/handlers/fs_mcp_session_test.go deleted file mode 100644 index 3ef000ca..00000000 --- a/internal/handlers/fs_mcp_session_test.go +++ /dev/null @@ -1,255 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "fmt" - "io" - "sync" - "testing" - "time" - - sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" - - mcptools "github.com/memohai/memoh/internal/mcp" -) - -type fakeMCPConnection struct { - mu sync.Mutex - writes []*sdkjsonrpc.Request - readCh chan sdkjsonrpc.Message - closed chan struct{} - closeMu sync.Once - onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) -} - -func newFakeMCPConnection(onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error)) *fakeMCPConnection { - return &fakeMCPConnection{ - writes: make([]*sdkjsonrpc.Request, 0, 16), - readCh: make(chan sdkjsonrpc.Message, 32), - closed: make(chan struct{}), - onWrite: onWrite, - } -} - -func (c *fakeMCPConnection) Read(ctx context.Context) (sdkjsonrpc.Message, error) { - select { - case <-c.closed: - return nil, io.EOF - case <-ctx.Done(): - return nil, ctx.Err() - case msg, ok := <-c.readCh: - if !ok { - return nil, io.EOF - } - return msg, nil - } -} - -func (c *fakeMCPConnection) Write(ctx context.Context, msg sdkjsonrpc.Message) error { - req, ok := msg.(*sdkjsonrpc.Request) - if !ok { - return fmt.Errorf("unsupported message type: %T", msg) - } - cloned := cloneJSONRPCRequest(req) - c.mu.Lock() - c.writes = append(c.writes, cloned) - c.mu.Unlock() - - if c.onWrite == nil { - return nil - } - resp, err := c.onWrite(cloned) - if err != nil { - return err - } - if resp == nil { - return nil - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-c.closed: - return io.EOF - case c.readCh <- resp: - return nil - } -} - -func (c *fakeMCPConnection) Close() error { - c.closeMu.Do(func() { - close(c.closed) - close(c.readCh) - }) - return nil -} - -func (c *fakeMCPConnection) SessionID() string { - return "test-session" -} - -func cloneJSONRPCRequest(req *sdkjsonrpc.Request) *sdkjsonrpc.Request { - if req == nil { - return nil - } - params := append([]byte(nil), req.Params...) - return &sdkjsonrpc.Request{ - ID: req.ID, - Method: req.Method, - Params: params, - Extra: req.Extra, - } -} - -func jsonRPCSuccessResponse(id sdkjsonrpc.ID, payload map[string]any) *sdkjsonrpc.Response { - body, _ := json.Marshal(payload) - return &sdkjsonrpc.Response{ - ID: id, - Result: body, - } -} - -func newTestMCPSession(conn *fakeMCPConnection) *mcpSession { - return &mcpSession{ - pending: map[string]chan *sdkjsonrpc.Response{}, - conn: conn, - closed: make(chan struct{}), - } -} - -func TestMCPSessionRetriesInitializeAfterFailure(t *testing.T) { - initCalls := 0 - conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { - switch req.Method { - case "initialize": - initCalls++ - if initCalls == 1 { - return &sdkjsonrpc.Response{ - ID: req.ID, - Error: &sdkjsonrpc.Error{ - Code: -32603, - Message: "temporary init failure", - }, - }, nil - } - return jsonRPCSuccessResponse(req.ID, map[string]any{ - "protocolVersion": "2025-06-18", - }), nil - case "tools/list": - return jsonRPCSuccessResponse(req.ID, map[string]any{ - "tools": []any{}, - }), nil - default: - return nil, nil - } - }) - session := newTestMCPSession(conn) - go session.readLoop() - defer session.closeWithError(io.EOF) - - _, firstErr := session.call(context.Background(), mcptools.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcptools.RawStringID("1"), - Method: "tools/list", - }) - if firstErr == nil { - t.Fatalf("first call should fail when initialize fails") - } - - secondPayload, secondErr := session.call(context.Background(), mcptools.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcptools.RawStringID("2"), - Method: "tools/list", - }) - if secondErr != nil { - t.Fatalf("second call should recover by retrying initialize: %v", secondErr) - } - if initCalls != 2 { - t.Fatalf("initialize should be retried once, got calls: %d", initCalls) - } - result, ok := secondPayload["result"].(map[string]any) - if !ok { - t.Fatalf("missing tools/list result: %#v", secondPayload) - } - if _, ok := result["tools"].([]any); !ok { - t.Fatalf("missing tools field: %#v", result) - } -} - -func TestMCPSessionExplicitInitializeDoesNotDuplicateInitialize(t *testing.T) { - initializeCalls := 0 - initializedNotifications := 0 - conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { - switch req.Method { - case "initialize": - initializeCalls++ - return jsonRPCSuccessResponse(req.ID, map[string]any{ - "protocolVersion": "2025-06-18", - }), nil - case "notifications/initialized": - initializedNotifications++ - return nil, nil - case "tools/list": - return jsonRPCSuccessResponse(req.ID, map[string]any{ - "tools": []any{}, - }), nil - default: - return nil, nil - } - }) - session := newTestMCPSession(conn) - go session.readLoop() - defer session.closeWithError(io.EOF) - - _, initErr := session.call(context.Background(), mcptools.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcptools.RawStringID("100"), - Method: "initialize", - Params: json.RawMessage(`{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"v1"}}`), - }) - if initErr != nil { - t.Fatalf("explicit initialize should succeed: %v", initErr) - } - - _, listErr := session.call(context.Background(), mcptools.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcptools.RawStringID("101"), - Method: "tools/list", - }) - if listErr != nil { - t.Fatalf("tools/list after initialize should succeed: %v", listErr) - } - if initializeCalls != 1 { - t.Fatalf("initialize should not be duplicated, got: %d", initializeCalls) - } - if initializedNotifications != 1 { - t.Fatalf("should send exactly one notifications/initialized, got: %d", initializedNotifications) - } -} - -func TestMCPSessionRemovesPendingOnContextCancel(t *testing.T) { - conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { - // Intentionally do not reply; caller should timeout. - return nil, nil - }) - session := newTestMCPSession(conn) - session.initState = mcpSessionInitStateReady - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) - defer cancel() - _, err := session.call(ctx, mcptools.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcptools.RawStringID("200"), - Method: "tools/list", - }) - if err == nil { - t.Fatalf("call should fail on context timeout") - } - - session.pendingMu.Lock() - pendingCount := len(session.pending) - session.pendingMu.Unlock() - if pendingCount != 0 { - t.Fatalf("pending map should be empty after cancellation, got: %d", pendingCount) - } -} diff --git a/internal/handlers/mcp_federation_gateway.go b/internal/handlers/mcp_federation_gateway.go index 6c5ed256..0955c903 100644 --- a/internal/handlers/mcp_federation_gateway.go +++ b/internal/handlers/mcp_federation_gateway.go @@ -279,9 +279,6 @@ func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, if err != nil { return nil, err } - if err := g.handler.validateMCPContainer(ctx, containerID, botID); err != nil { - return nil, err - } if err := g.handler.ensureContainerAndTask(ctx, containerID, botID); err != nil { return nil, err } diff --git a/internal/handlers/mcp_session_test.go b/internal/handlers/mcp_session_test.go new file mode 100644 index 00000000..aee97c5a --- /dev/null +++ b/internal/handlers/mcp_session_test.go @@ -0,0 +1,385 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sync" + "testing" + "time" + + sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" + + mcptools "github.com/memohai/memoh/internal/mcp" +) + +// fakeMCPConnection implements sdkmcp.Connection for testing. +// onWrite is called synchronously when Write is called; if it returns a +// non-nil Response the response is queued to be returned by Read. +type fakeMCPConnection struct { + mu sync.Mutex + writes []*sdkjsonrpc.Request + readCh chan sdkjsonrpc.Message + closed chan struct{} + closeMu sync.Once + onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) +} + +func newFakeMCPConnection(onWrite func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error)) *fakeMCPConnection { + return &fakeMCPConnection{ + writes: make([]*sdkjsonrpc.Request, 0, 16), + readCh: make(chan sdkjsonrpc.Message, 32), + closed: make(chan struct{}), + onWrite: onWrite, + } +} + +func (c *fakeMCPConnection) Read(ctx context.Context) (sdkjsonrpc.Message, error) { + select { + case <-c.closed: + return nil, io.EOF + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-c.readCh: + if !ok { + return nil, io.EOF + } + return msg, nil + } +} + +func (c *fakeMCPConnection) Write(ctx context.Context, msg sdkjsonrpc.Message) error { + req, ok := msg.(*sdkjsonrpc.Request) + if !ok { + return fmt.Errorf("unsupported message type: %T", msg) + } + cloned := cloneJSONRPCRequest(req) + c.mu.Lock() + c.writes = append(c.writes, cloned) + c.mu.Unlock() + + if c.onWrite == nil { + return nil + } + resp, err := c.onWrite(cloned) + if err != nil { + return err + } + if resp == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closed: + return io.EOF + case c.readCh <- resp: + return nil + } +} + +func (c *fakeMCPConnection) Close() error { + c.closeMu.Do(func() { + close(c.closed) + close(c.readCh) + }) + return nil +} + +func (c *fakeMCPConnection) SessionID() string { return "test-session" } + +func cloneJSONRPCRequest(req *sdkjsonrpc.Request) *sdkjsonrpc.Request { + if req == nil { + return nil + } + params := append([]byte(nil), req.Params...) + return &sdkjsonrpc.Request{ + ID: req.ID, + Method: req.Method, + Params: params, + Extra: req.Extra, + } +} + +func jsonRPCSuccessResponse(id sdkjsonrpc.ID, payload map[string]any) *sdkjsonrpc.Response { + body, _ := json.Marshal(payload) + return &sdkjsonrpc.Response{ID: id, Result: body} +} + +func newTestMCPSession(conn *fakeMCPConnection) *mcpSession { + readCtx, cancelRead := context.WithCancel(context.Background()) + return &mcpSession{ + pending: map[string]chan *sdkjsonrpc.Response{}, + conn: conn, + closed: make(chan struct{}), + readCtx: readCtx, + cancelRead: cancelRead, + } +} + +// --- Tests --- + +// TestMCPSession_CallRaw_ResponseEnvelope verifies that callRaw returns a +// standard JSON-RPC envelope {"jsonrpc","id","result"}. +func TestMCPSession_CallRaw_ResponseEnvelope(t *testing.T) { + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + return jsonRPCSuccessResponse(req.ID, map[string]any{"tools": []any{}}), nil + }) + sess := newTestMCPSession(conn) + sess.initState = mcpSessionInitStateReady + go sess.readLoop() + defer sess.closeWithError(io.EOF) + + payload, err := sess.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("1"), + Method: "tools/list", + }) + if err != nil { + t.Fatalf("call failed: %v", err) + } + + // Verify standard JSON-RPC envelope. + if payload["jsonrpc"] != "2.0" { + t.Errorf("expected jsonrpc=2.0, got %v", payload["jsonrpc"]) + } + if _, ok := payload["id"]; !ok { + t.Errorf("expected 'id' field in envelope, got %v", payload) + } + if _, ok := payload["result"]; !ok { + t.Errorf("expected 'result' field in envelope, got %v", payload) + } + if _, ok := payload["error"]; ok { + t.Errorf("unexpected 'error' field in success envelope") + } +} + +// TestMCPSession_CallRaw_ErrorEnvelope verifies that server-side errors are +// returned as {"jsonrpc","id","error"} envelope, not a Go error. +func TestMCPSession_CallRaw_ErrorEnvelope(t *testing.T) { + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + return &sdkjsonrpc.Response{ + ID: req.ID, + Error: &sdkjsonrpc.Error{Code: -32601, Message: "Method not found"}, + }, nil + }) + sess := newTestMCPSession(conn) + sess.initState = mcpSessionInitStateReady + go sess.readLoop() + defer sess.closeWithError(io.EOF) + + payload, err := sess.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("2"), + Method: "unknown/method", + }) + if err != nil { + t.Fatalf("unexpected Go error (server errors should be in envelope): %v", err) + } + errField, ok := payload["error"].(map[string]any) + if !ok { + t.Fatalf("expected 'error' field in envelope, got %v", payload) + } + if errField["code"] != int64(-32601) { + t.Errorf("unexpected error code: %v", errField["code"]) + } + if _, ok := payload["result"]; ok { + t.Errorf("unexpected 'result' field in error envelope") + } +} + +// TestMCPSession_InitializeRetryAfterFailure tests that the session retries +// the initialize handshake after the first attempt fails. +func TestMCPSession_InitializeRetryAfterFailure(t *testing.T) { + initCalls := 0 + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + switch req.Method { + case "initialize": + initCalls++ + if initCalls == 1 { + return &sdkjsonrpc.Response{ + ID: req.ID, + Error: &sdkjsonrpc.Error{Code: -32603, Message: "temporary init failure"}, + }, nil + } + return jsonRPCSuccessResponse(req.ID, map[string]any{"protocolVersion": "2025-06-18"}), nil + case "tools/list": + return jsonRPCSuccessResponse(req.ID, map[string]any{"tools": []any{}}), nil + default: + return nil, nil + } + }) + sess := newTestMCPSession(conn) + go sess.readLoop() + defer sess.closeWithError(io.EOF) + + _, firstErr := sess.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("1"), + Method: "tools/list", + }) + if firstErr == nil { + t.Fatal("first call should fail when initialize fails") + } + + secondPayload, secondErr := sess.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("2"), + Method: "tools/list", + }) + if secondErr != nil { + t.Fatalf("second call should recover by retrying initialize: %v", secondErr) + } + if initCalls != 2 { + t.Fatalf("initialize should be retried once, got calls: %d", initCalls) + } + result, ok := secondPayload["result"].(map[string]any) + if !ok { + t.Fatalf("missing tools/list result in envelope: %#v", secondPayload) + } + if _, ok := result["tools"].([]any); !ok { + t.Fatalf("missing tools field: %#v", result) + } +} + +// TestMCPSession_ExplicitInitializeNoDoubling tests that sending an explicit +// "initialize" call does not cause the session to auto-initialize again. +func TestMCPSession_ExplicitInitializeNoDoubling(t *testing.T) { + initializeCalls := 0 + initializedNotifications := 0 + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + switch req.Method { + case "initialize": + initializeCalls++ + return jsonRPCSuccessResponse(req.ID, map[string]any{"protocolVersion": "2025-06-18"}), nil + case "notifications/initialized": + initializedNotifications++ + return nil, nil + case "tools/list": + return jsonRPCSuccessResponse(req.ID, map[string]any{"tools": []any{}}), nil + default: + return nil, nil + } + }) + sess := newTestMCPSession(conn) + go sess.readLoop() + defer sess.closeWithError(io.EOF) + + _, initErr := sess.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("100"), + Method: "initialize", + Params: json.RawMessage(`{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"v1"}}`), + }) + if initErr != nil { + t.Fatalf("explicit initialize should succeed: %v", initErr) + } + + _, listErr := sess.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("101"), + Method: "tools/list", + }) + if listErr != nil { + t.Fatalf("tools/list after initialize should succeed: %v", listErr) + } + if initializeCalls != 1 { + t.Fatalf("initialize should not be duplicated, got: %d", initializeCalls) + } + if initializedNotifications != 1 { + t.Fatalf("should send exactly one notifications/initialized, got: %d", initializedNotifications) + } +} + +// TestMCPSession_PendingCleanupOnContextCancel tests that cancelling a request +// context removes it from the pending map. +func TestMCPSession_PendingCleanupOnContextCancel(t *testing.T) { + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + // Never reply — caller should time out. + return nil, nil + }) + sess := newTestMCPSession(conn) + sess.initState = mcpSessionInitStateReady + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + _, err := sess.call(ctx, mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("200"), + Method: "tools/list", + }) + if err == nil { + t.Fatal("call should fail on context timeout") + } + + sess.pendingMu.Lock() + pendingCount := len(sess.pending) + sess.pendingMu.Unlock() + if pendingCount != 0 { + t.Fatalf("pending map should be empty after cancellation, got: %d", pendingCount) + } +} + +// TestMCPSession_PendingCleanupOnClose tests that closing the session drains +// all pending channels (callers unblock). +func TestMCPSession_PendingCleanupOnClose(t *testing.T) { + conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + return nil, nil // never reply + }) + sess := newTestMCPSession(conn) + sess.initState = mcpSessionInitStateReady + + errCh := make(chan error, 1) + go func() { + _, err := sess.call(context.Background(), mcptools.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcptools.RawStringID("300"), + Method: "tools/list", + }) + errCh <- err + }() + + // Give goroutine time to register in pending. + time.Sleep(10 * time.Millisecond) + sess.closeWithError(io.EOF) + + select { + case err := <-errCh: + if err == nil { + t.Error("expected error after session close, got nil") + } + case <-time.After(2 * time.Second): + t.Fatal("call did not unblock after session close") + } + + sess.pendingMu.Lock() + pendingCount := len(sess.pending) + sess.pendingMu.Unlock() + if pendingCount != 0 { + t.Fatalf("pending map should be empty after close, got: %d", pendingCount) + } +} + +// TestMCPSession_ReadLoopCancelOnClose tests that closing the session +// (which cancels readCtx) causes readLoop to exit. +func TestMCPSession_ReadLoopCancelOnClose(t *testing.T) { + conn := newFakeMCPConnection(nil) + sess := newTestMCPSession(conn) + + loopDone := make(chan struct{}) + go func() { + sess.readLoop() + close(loopDone) + }() + + // Close the session; this should cancel readCtx and unblock readLoop. + sess.closeWithError(io.EOF) + + select { + case <-loopDone: + // readLoop exited as expected. + case <-time.After(2 * time.Second): + t.Fatal("readLoop did not exit after session close") + } +} diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go index 846bd06d..b22e1976 100644 --- a/internal/handlers/mcp_stdio.go +++ b/internal/handlers/mcp_stdio.go @@ -1,13 +1,17 @@ package handlers import ( + "bufio" "context" + "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" "sort" "strings" + "sync" "time" "github.com/google/uuid" @@ -15,10 +19,11 @@ import ( sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" - ctr "github.com/memohai/memoh/internal/containerd" mcptools "github.com/memohai/memoh/internal/mcp" + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" ) +// MCPStdioRequest represents a request to create an MCP stdio session. type MCPStdioRequest struct { Name string `json:"name"` Command string `json:"command"` @@ -27,12 +32,527 @@ type MCPStdioRequest struct { Cwd string `json:"cwd"` } +// MCPStdioResponse represents the response from creating an MCP stdio session. type MCPStdioResponse struct { ConnectionID string `json:"connection_id"` URL string `json:"url"` Tools []string `json:"tools,omitempty"` } +// mcpSession represents an MCP session over stdio. +type mcpSession struct { + stdin io.WriteCloser + stdout io.ReadCloser + stderr io.ReadCloser + readCtx context.Context + cancelRead context.CancelFunc + initMu sync.Mutex + initState mcpSessionInitState + initWait chan struct{} + pendingMu sync.Mutex + pending map[string]chan *sdkjsonrpc.Response + conn sdkmcp.Connection + closed chan struct{} + closeOnce sync.Once + closeErr error + onClose func() +} + +type mcpSessionInitState uint8 + +const ( + mcpSessionInitStateNone mcpSessionInitState = iota + mcpSessionInitStateInitializing + mcpSessionInitStateInitialized + mcpSessionInitStateReady +) + +func (s *mcpSession) closeWithError(err error) { + s.closeOnce.Do(func() { + s.closeErr = err + close(s.closed) + if s.cancelRead != nil { + s.cancelRead() + } + s.pendingMu.Lock() + for _, ch := range s.pending { + close(ch) + } + s.pending = map[string]chan *sdkjsonrpc.Response{} + s.pendingMu.Unlock() + if s.conn != nil { + _ = s.conn.Close() + } + if s.stdin != nil { + _ = s.stdin.Close() + } + if s.stdout != nil { + _ = s.stdout.Close() + } + if s.stderr != nil { + _ = s.stderr.Close() + } + if s.onClose != nil { + s.onClose() + } + }) +} + +func (s *mcpSession) readLoop() { + if s.conn == nil { + s.closeWithError(io.EOF) + return + } + for { + msg, err := s.conn.Read(s.readCtx) + if err != nil { + if errors.Is(err, io.EOF) { + s.closeWithError(io.EOF) + return + } + s.closeWithError(err) + return + } + resp, ok := msg.(*sdkjsonrpc.Response) + if !ok || !resp.ID.IsValid() { + continue + } + id := sdkIDKey(resp.ID) + if id == "" { + continue + } + s.pendingMu.Lock() + ch, ok := s.pending[id] + if ok { + delete(s.pending, id) + } + s.pendingMu.Unlock() + if ok { + ch <- resp + close(ch) + } + } +} + +func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { + method := strings.TrimSpace(req.Method) + if method == "initialize" { + payload, err := s.callRaw(ctx, req) + if err != nil { + return nil, err + } + // If the server accepted our initialize, advance state so + // ensureInitialized will only send notifications/initialized next time. + if _, hasError := payload["error"]; !hasError { + s.initMu.Lock() + if s.initState < mcpSessionInitStateInitialized { + s.initState = mcpSessionInitStateInitialized + } + s.initMu.Unlock() + } + return payload, nil + } + if method != "notifications/initialized" { + if err := s.ensureInitialized(ctx); err != nil { + return nil, err + } + } + return s.callRaw(ctx, req) +} + +func (s *mcpSession) callRaw(ctx context.Context, req mcptools.JSONRPCRequest) (map[string]any, error) { + targetID, err := parseRawJSONRPCID(req.ID) + if err != nil { + return nil, err + } + target := sdkIDKey(targetID) + if target == "" { + return nil, fmt.Errorf("missing request id") + } + + respCh := make(chan *sdkjsonrpc.Response, 1) + s.pendingMu.Lock() + s.pending[target] = respCh + s.pendingMu.Unlock() + + callReq := &sdkjsonrpc.Request{ + ID: targetID, + Method: req.Method, + Params: req.Params, + } + if err := s.conn.Write(ctx, callReq); err != nil { + s.pendingMu.Lock() + delete(s.pending, target) + s.pendingMu.Unlock() + return nil, err + } + + select { + case resp, ok := <-respCh: + if !ok { + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + } + return sdkResponsePayload(resp) + case <-s.closed: + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + case <-ctx.Done(): + s.pendingMu.Lock() + delete(s.pending, target) + s.pendingMu.Unlock() + return nil, ctx.Err() + } +} + +// sdkResponsePayload wraps an SDK JSON-RPC response into a standard JSON-RPC +// envelope ({"jsonrpc":"2.0","id":...,"result":...} or "error":...). +func sdkResponsePayload(resp *sdkjsonrpc.Response) (map[string]any, error) { + if resp == nil { + return nil, io.EOF + } + if resp.Error != nil { + code := int64(-32603) + message := strings.TrimSpace(resp.Error.Error()) + if wireErr, ok := resp.Error.(*sdkjsonrpc.Error); ok { + code = wireErr.Code + message = strings.TrimSpace(wireErr.Message) + } + if message == "" { + message = "internal error" + } + return map[string]any{ + "jsonrpc": "2.0", + "id": sdkIDRaw(resp.ID), + "error": map[string]any{ + "code": code, + "message": message, + }, + }, nil + } + var result any + if len(resp.Result) > 0 { + if err := json.Unmarshal(resp.Result, &result); err != nil { + return nil, err + } + } + return map[string]any{ + "jsonrpc": "2.0", + "id": sdkIDRaw(resp.ID), + "result": result, + }, nil +} + +func sdkIDRaw(id sdkjsonrpc.ID) any { + if !id.IsValid() { + return nil + } + return id.Raw() +} + +func (s *mcpSession) notify(ctx context.Context, req mcptools.JSONRPCRequest) error { + if s.conn == nil { + return io.EOF + } + return s.conn.Write(ctx, &sdkjsonrpc.Request{ + Method: req.Method, + Params: req.Params, + }) +} + +func (s *mcpSession) ensureInitialized(ctx context.Context) error { + for { + s.initMu.Lock() + state := s.initState + + switch state { + case mcpSessionInitStateReady: + s.initMu.Unlock() + return nil + case mcpSessionInitStateInitializing: + waitCh := s.initWait + s.initMu.Unlock() + if waitCh == nil { + continue + } + select { + case <-waitCh: + continue + case <-ctx.Done(): + return ctx.Err() + case <-s.closed: + if s.closeErr != nil { + return s.closeErr + } + return io.EOF + } + case mcpSessionInitStateInitialized: + waitCh := make(chan struct{}) + s.initState = mcpSessionInitStateInitializing + s.initWait = waitCh + s.initMu.Unlock() + + err := s.sendInitializedNotification(ctx) + + s.initMu.Lock() + if err == nil { + s.initState = mcpSessionInitStateReady + } else { + s.initState = mcpSessionInitStateInitialized + } + s.initWait = nil + close(waitCh) + s.initMu.Unlock() + + if err != nil { + return err + } + return nil + default: + waitCh := make(chan struct{}) + s.initState = mcpSessionInitStateInitializing + s.initWait = waitCh + s.initMu.Unlock() + + nextState, err := s.initializeHandshake(ctx) + + s.initMu.Lock() + s.initState = nextState + s.initWait = nil + close(waitCh) + s.initMu.Unlock() + + if err != nil { + return err + } + if nextState == mcpSessionInitStateReady { + return nil + } + } + } +} + +func (s *mcpSession) initializeHandshake(ctx context.Context) (mcpSessionInitState, error) { + initID, _ := sdkjsonrpc.MakeID("init") + params, _ := json.Marshal(map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "memoh", + "version": "1.0.0", + }, + }) + initResp, err := s.invokeCall(ctx, &sdkjsonrpc.Request{ + ID: initID, + Method: "initialize", + Params: params, + }) + if err != nil { + return mcpSessionInitStateNone, err + } + if initResp.Error != nil { + return mcpSessionInitStateNone, initResp.Error + } + if err := s.sendInitializedNotification(ctx); err != nil { + return mcpSessionInitStateInitialized, err + } + return mcpSessionInitStateReady, nil +} + +func (s *mcpSession) sendInitializedNotification(ctx context.Context) error { + if s.conn == nil { + return io.EOF + } + return s.conn.Write(ctx, &sdkjsonrpc.Request{ + Method: "notifications/initialized", + }) +} + +func (s *mcpSession) invokeCall(ctx context.Context, req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + if s.conn == nil { + return nil, io.EOF + } + if req == nil || !req.ID.IsValid() { + return nil, fmt.Errorf("missing request id") + } + key := sdkIDKey(req.ID) + if key == "" { + return nil, fmt.Errorf("invalid request id") + } + + respCh := make(chan *sdkjsonrpc.Response, 1) + s.pendingMu.Lock() + s.pending[key] = respCh + s.pendingMu.Unlock() + + if err := s.conn.Write(ctx, req); err != nil { + s.removePending(key) + return nil, err + } + + select { + case resp, ok := <-respCh: + if !ok { + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + } + return resp, nil + case <-s.closed: + if s.closeErr != nil { + return nil, s.closeErr + } + return nil, io.EOF + case <-ctx.Done(): + s.removePending(key) + return nil, ctx.Err() + } +} + +func (s *mcpSession) removePending(key string) { + if strings.TrimSpace(key) == "" { + return + } + s.pendingMu.Lock() + delete(s.pending, key) + s.pendingMu.Unlock() +} + +func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { + if len(raw) == 0 { + return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + } + var idValue any + if err := json.Unmarshal(raw, &idValue); err != nil { + return sdkjsonrpc.ID{}, err + } + id, err := sdkjsonrpc.MakeID(idValue) + if err != nil { + return sdkjsonrpc.ID{}, err + } + if !id.IsValid() { + return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + } + return id, nil +} + +func sdkIDKey(id sdkjsonrpc.ID) string { + if !id.IsValid() { + return "" + } + raw, _ := json.Marshal(id.Raw()) + return string(raw) +} + +func startMCPStderrLogger(stderr io.ReadCloser, containerID string, logger *slog.Logger) { + if stderr == nil { + return + } + go func() { + scanner := bufio.NewScanner(stderr) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + logger.Warn("mcp stderr", slog.String("container_id", containerID), slog.String("message", line)) + } + if err := scanner.Err(); err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "closed pipe") { + return + } + logger.Error("mcp stderr read failed", slog.Any("error", err), slog.String("container_id", containerID)) + } + }() +} + +func extractToolNames(payload map[string]any) []string { + result, ok := payload["result"].(map[string]any) + if !ok { + return nil + } + rawTools, ok := result["tools"].([]any) + if !ok { + return nil + } + names := make([]string, 0, len(rawTools)) + for _, raw := range rawTools { + item, ok := raw.(map[string]any) + if !ok { + continue + } + name, _ := item["name"].(string) + name = strings.TrimSpace(name) + if name == "" { + continue + } + names = append(names, name) + } + sort.Strings(names) + return names +} + +func buildShellCommand(req MCPStdioRequest) string { + cmd := strings.TrimSpace(req.Command) + if cmd == "" { + return "" + } + parts := make([]string, 0, len(req.Args)+1) + parts = append(parts, escapeShellArg(cmd)) + for _, arg := range req.Args { + parts = append(parts, escapeShellArg(arg)) + } + command := strings.Join(parts, " ") + + assignments := []string{} + for _, pair := range buildEnvPairs(req.Env) { + assignments = append(assignments, escapeShellArg(pair)) + } + if len(assignments) > 0 { + command = strings.Join(assignments, " ") + " " + command + } + if strings.TrimSpace(req.Cwd) != "" { + command = "cd " + escapeShellArg(req.Cwd) + " && " + command + } + return command +} + +func escapeShellArg(value string) string { + if value == "" { + return "''" + } + if !strings.ContainsAny(value, " \t\n'\"\\$&;|<>*?()[]{}!`") { + return value + } + return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'" +} + +func buildEnvPairs(env map[string]string) []string { + if len(env) == 0 { + return nil + } + keys := make([]string, 0, len(env)) + for k := range env { + if strings.TrimSpace(k) != "" { + keys = append(keys, k) + } + } + sort.Strings(keys) + out := make([]string, 0, len(keys)) + for _, k := range keys { + out = append(out, fmt.Sprintf("%s=%s", k, env[k])) + } + return out +} + +// ---------- MCP Stdio Handlers ---------- + type mcpStdioSession struct { id string botID string @@ -71,9 +591,6 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") } - if err := h.validateMCPContainer(ctx, containerID, botID); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } if err := h.ensureContainerAndTask(ctx, containerID, botID); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -169,25 +686,82 @@ func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error { } func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context, containerID string, req MCPStdioRequest) (*mcpSession, error) { - args := append([]string{strings.TrimSpace(req.Command)}, req.Args...) - env := buildEnvPairs(req.Env) - execSession, err := h.service.ExecTaskStreaming(ctx, containerID, ctr.ExecTaskRequest{ - Args: args, - Env: env, - WorkDir: strings.TrimSpace(req.Cwd), - FIFODir: h.mcpFIFODir(), - }) + // Extract bot_id from container_id (remove "mcp-" prefix) + botID := strings.TrimPrefix(containerID, "mcp-") + if botID == "" || botID == containerID { + return nil, fmt.Errorf("invalid container_id: %s", containerID) + } + + // Get gRPC client for the bot container via manager + client, err := h.manager.MCPClient(ctx, botID) + if err != nil { + return nil, fmt.Errorf("get container client: %w", err) + } + + command := buildShellCommand(req) + + // Create bidirectional exec stream + execStream, err := client.ExecStream(ctx, command, strings.TrimSpace(req.Cwd), 0) if err != nil { return nil, err } + // Create pipes for stdin/stdout/stderr + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + stderrR, stderrW := io.Pipe() + + readCtx, cancelRead := context.WithCancel(context.Background()) sess := &mcpSession{ - stdin: execSession.Stdin, - stdout: execSession.Stdout, - stderr: execSession.Stderr, - pending: make(map[string]chan *sdkjsonrpc.Response), - closed: make(chan struct{}), + stdin: stdinW, + stdout: stdoutR, + stderr: stderrR, + readCtx: readCtx, + cancelRead: cancelRead, + pending: make(map[string]chan *sdkjsonrpc.Response), + closed: make(chan struct{}), } + + // Forward stdin to gRPC stream + go func() { + buf := make([]byte, 32*1024) + for { + n, err := stdinR.Read(buf) + if n > 0 { + _ = execStream.SendStdin(buf[:n]) + } + if err != nil { + break + } + } + _ = stdinR.Close() + }() + + // Forward gRPC stdout/stderr to pipes + go func() { + for { + output, err := execStream.Recv() + if err != nil { + if err != io.EOF { + h.logger.Debug("exec stream recv done", slog.Any("error", err)) + } + _ = stdoutW.Close() + _ = stderrW.Close() + break + } + switch output.GetStream() { + case pb.ExecOutput_STDOUT: + _, _ = stdoutW.Write(output.GetData()) + case pb.ExecOutput_STDERR: + _, _ = stderrW.Write(output.GetData()) + case pb.ExecOutput_EXIT: + _ = stdoutW.Close() + _ = stderrW.Close() + return + } + } + }() + transport := &sdkmcp.IOTransport{ Reader: sess.stdout, Writer: sess.stdin, @@ -198,42 +772,15 @@ func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context return nil, err } sess.conn = conn - h.startMCPStderrLogger(execSession.Stderr, containerID) + startMCPStderrLogger(sess.stderr, containerID, h.logger) go sess.readLoop() go func() { - _, err := execSession.Wait() - if err != nil { - if isBenignMCPSessionExit(err) { - sess.closeWithError(io.EOF) - return - } - h.logger.Error("mcp stdio session exited", slog.Any("error", err), slog.String("container_id", containerID)) - sess.closeWithError(err) - return - } - sess.closeWithError(io.EOF) + <-sess.closed + _ = execStream.Close() }() return sess, nil } -func buildEnvPairs(env map[string]string) []string { - if len(env) == 0 { - return nil - } - keys := make([]string, 0, len(env)) - for k := range env { - if strings.TrimSpace(k) != "" { - keys = append(keys, k) - } - } - sort.Strings(keys) - out := make([]string, 0, len(keys)) - for _, k := range keys { - out = append(out, fmt.Sprintf("%s=%s", k, env[k])) - } - return out -} - func (h *ContainerdHandler) probeMCPTools(ctx context.Context, sess *mcpSession, botID, name string) []string { if sess == nil { return nil @@ -268,64 +815,3 @@ func (h *ContainerdHandler) probeMCPTools(ctx context.Context, sess *mcpSession, } return tools } - -func extractToolNames(payload map[string]any) []string { - result, ok := payload["result"].(map[string]any) - if !ok { - return nil - } - rawTools, ok := result["tools"].([]any) - if !ok { - return nil - } - names := make([]string, 0, len(rawTools)) - for _, raw := range rawTools { - item, ok := raw.(map[string]any) - if !ok { - continue - } - name, _ := item["name"].(string) - name = strings.TrimSpace(name) - if name == "" { - continue - } - names = append(names, name) - } - sort.Strings(names) - return names -} - -func buildShellCommand(req MCPStdioRequest) string { - cmd := strings.TrimSpace(req.Command) - if cmd == "" { - return "" - } - parts := make([]string, 0, len(req.Args)+1) - parts = append(parts, escapeShellArg(cmd)) - for _, arg := range req.Args { - parts = append(parts, escapeShellArg(arg)) - } - command := strings.Join(parts, " ") - - assignments := []string{} - for _, pair := range buildEnvPairs(req.Env) { - assignments = append(assignments, escapeShellArg(pair)) - } - if len(assignments) > 0 { - command = strings.Join(assignments, " ") + " " + command - } - if strings.TrimSpace(req.Cwd) != "" { - command = "cd " + escapeShellArg(req.Cwd) + " && " + command - } - return command -} - -func escapeShellArg(value string) string { - if value == "" { - return "''" - } - if !strings.ContainsAny(value, " \t\n'\"\\$&;|<>*?()[]{}!`") { - return value - } - return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'" -} diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index 6dad81c9..aff87253 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -16,7 +16,7 @@ import ( "github.com/memohai/memoh/internal/accounts" "github.com/memohai/memoh/internal/bots" - fsops "github.com/memohai/memoh/internal/fs" + "github.com/memohai/memoh/internal/mcp/mcpclient" memprovider "github.com/memohai/memoh/internal/memory/provider" storefs "github.com/memohai/memoh/internal/memory/storefs" "github.com/memohai/memoh/internal/settings" @@ -115,13 +115,13 @@ func (h *MemoryHandler) resolveProvider(ctx context.Context, botID string) mempr return p } -// SetFSService sets the optional filesystem persistence layer. -func (h *MemoryHandler) SetFSService(fs *fsops.Service) { - if fs == nil { +// SetMCPClientProvider sets the gRPC client provider for filesystem persistence. +func (h *MemoryHandler) SetMCPClientProvider(p mcpclient.Provider) { + if p == nil { h.memoryStore = nil return } - h.memoryStore = storefs.New(fs) + h.memoryStore = storefs.New(p) } // Register registers chat-level memory routes. @@ -625,11 +625,11 @@ func (h *MemoryHandler) requireBotAccess(c echo.Context) (string, error) { } // NewBuiltinMemoryRuntime keeps provider architecture while using file memory backend. -func NewBuiltinMemoryRuntime(fs *fsops.Service) any { - if fs == nil { +func NewBuiltinMemoryRuntime(p mcpclient.Provider) any { + if p == nil { return nil } - return &fileMemoryRuntime{store: storefs.New(fs)} + return &fileMemoryRuntime{store: storefs.New(p)} } type fileMemoryRuntime struct { diff --git a/internal/handlers/skills.go b/internal/handlers/skills.go index 445e84e3..f94ee4c7 100644 --- a/internal/handlers/skills.go +++ b/internal/handlers/skills.go @@ -2,16 +2,20 @@ package handlers import ( "context" + "fmt" "net/http" - "os" "path" - "path/filepath" "strings" "github.com/labstack/echo/v4" "gopkg.in/yaml.v3" + + "github.com/memohai/memoh/internal/config" + "github.com/memohai/memoh/internal/mcp/mcpclient" ) +const skillsDirPath = config.DefaultDataMount + "/.skills" + type SkillItem struct { Name string `json:"name"` Description string `json:"description"` @@ -50,35 +54,13 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error { if err != nil { return err } - skillsDir, err := h.ensureSkillsDirHost(botID) + skills, err := h.loadSkillsFromContainer(c.Request().Context(), botID) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - entries, err := listSkillEntries(skillsDir) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + for i := range skills { + skills[i].Raw = skills[i].Content } - - skills := make([]SkillItem, 0, len(entries)) - for _, entry := range entries { - skillPath, name := skillPathForEntry(entry) - if skillPath == "" { - continue - } - raw, err := h.readSkillFile(skillsDir, skillPath) - if err != nil { - continue - } - parsed := parseSkillFile(raw, name) - skills = append(skills, SkillItem{ - Name: parsed.Name, - Description: parsed.Description, - Content: parsed.Content, - Metadata: parsed.Metadata, - Raw: raw, - }) - } - return c.JSON(http.StatusOK, SkillsResponse{Skills: skills}) } @@ -105,22 +87,24 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "skills is required") } - skillsDir, err := h.ensureSkillsDirHost(botID) + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("container not reachable: %v", err)) } + for _, raw := range req.Skills { parsed := parseSkillFile(raw, "") if !isValidSkillName(parsed.Name) { return echo.NewHTTPError(http.StatusBadRequest, "skill must have a valid name in YAML frontmatter") } - dirPath := filepath.Join(skillsDir, parsed.Name) - if err := os.MkdirAll(dirPath, 0o755); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + dirPath := path.Join(skillsDirPath, parsed.Name) + if err := client.Mkdir(ctx, dirPath); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("mkdir failed: %v", err)) } - filePath := filepath.Join(dirPath, "SKILL.md") - if err := os.WriteFile(filePath, []byte(raw), 0o644); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + filePath := path.Join(dirPath, "SKILL.md") + if err := client.WriteFile(ctx, filePath, []byte(raw)); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("write failed: %v", err)) } } @@ -150,9 +134,10 @@ func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "names is required") } - skillsDir, err := h.ensureSkillsDirHost(botID) + ctx := c.Request().Context() + client, err := h.getGRPCClient(ctx, botID) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("container not reachable: %v", err)) } for _, name := range req.Names { @@ -160,154 +145,85 @@ func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { if !isValidSkillName(skillName) { return echo.NewHTTPError(http.StatusBadRequest, "invalid skill name") } - deletePath := filepath.Join(skillsDir, skillName) - if err := os.RemoveAll(deletePath); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } + _ = client.DeleteFile(ctx, path.Join(skillsDirPath, skillName), true) } return c.JSON(http.StatusOK, skillsOpResponse{OK: true}) } // LoadSkills loads all skills from the container for the given bot. -// This implements chat.SkillLoader. func (h *ContainerdHandler) LoadSkills(ctx context.Context, botID string) ([]SkillItem, error) { - skillsDir, err := h.ensureSkillsDirHost(botID) + return h.loadSkillsFromContainer(ctx, botID) +} + +func (h *ContainerdHandler) loadSkillsFromContainer(ctx context.Context, botID string) ([]SkillItem, error) { + client, err := h.getGRPCClient(ctx, botID) if err != nil { return nil, err } - entries, err := listSkillEntries(skillsDir) + entries, err := client.ListDir(ctx, skillsDirPath, false) if err != nil { - return nil, err + return []SkillItem{}, nil } - skills := make([]SkillItem, 0, len(entries)) + var skills []SkillItem for _, entry := range entries { - skillPath, name := skillPathForEntry(entry) - if skillPath == "" { + if !entry.GetIsDir() { + if path.Base(entry.GetPath()) == "SKILL.md" { + filePath := path.Join(skillsDirPath, "SKILL.md") + raw, readErr := readContainerSkillFile(ctx, client, filePath) + if readErr != nil { + continue + } + parsed := parseSkillFile(raw, "default") + skills = append(skills, skillItemFromParsed(parsed, raw)) + } continue } - raw, err := h.readSkillFile(skillsDir, skillPath) - if err != nil { + name := path.Base(entry.GetPath()) + if name == "" || name == "." { + continue + } + filePath := path.Join(skillsDirPath, name, "SKILL.md") + raw, readErr := readContainerSkillFile(ctx, client, filePath) + if readErr != nil { continue } parsed := parseSkillFile(raw, name) - skills = append(skills, SkillItem{ - Name: parsed.Name, - Description: parsed.Description, - Content: parsed.Content, - Metadata: parsed.Metadata, - }) + skills = append(skills, skillItemFromParsed(parsed, raw)) } return skills, nil } -func (h *ContainerdHandler) ensureSkillsDirHost(botID string) (string, error) { - root, err := h.ensureBotDataRoot(botID) +func readContainerSkillFile(ctx context.Context, client *mcpclient.Client, filePath string) (string, error) { + resp, err := client.ReadFile(ctx, filePath, 0, 0) if err != nil { return "", err } - skillsDir := filepath.Join(root, ".skills") - if err := os.MkdirAll(skillsDir, 0o755); err != nil { - return "", err - } - return skillsDir, nil + return resp.GetContent(), nil } -func (h *ContainerdHandler) readSkillFile(skillsDir, filePath string) (string, error) { - safeRel := strings.TrimPrefix(strings.TrimPrefix(filePath, ".skills/"), "./.skills/") - if safeRel == "" { - return "", os.ErrInvalid +func skillItemFromParsed(parsed parsedSkill, raw string) SkillItem { + return SkillItem{ + Name: parsed.Name, + Description: parsed.Description, + Content: parsed.Content, + Metadata: parsed.Metadata, + Raw: raw, } - target := filepath.Join(skillsDir, filepath.FromSlash(safeRel)) - data, err := os.ReadFile(target) - if err != nil { - return "", err - } - return string(data), nil } -func listSkillEntries(skillsDir string) ([]skillEntry, error) { - dirEntries, err := os.ReadDir(skillsDir) - if err != nil { - return nil, err - } - entries := make([]skillEntry, 0, len(dirEntries)) - for _, entry := range dirEntries { - name := entry.Name() - if name == "" { - continue - } - if entry.IsDir() { - entries = append(entries, skillEntry{ - Path: path.Join(".skills", name), - IsDir: true, - }) - continue - } - if name == "SKILL.md" { - entries = append(entries, skillEntry{ - Path: path.Join(".skills", name), - IsDir: false, - }) - } - } - return entries, nil -} +// --- parsing logic (unchanged) --- -type skillEntry struct { - Path string - IsDir bool -} - -func skillNameFromPath(rel string) string { - if rel == "" || rel == "SKILL.md" { - return "default" - } - parent := path.Dir(rel) - if parent == "." { - return "default" - } - return path.Base(parent) -} - -func skillPathForEntry(entry skillEntry) (string, string) { - rel := strings.TrimPrefix(entry.Path, ".skills/") - if rel == entry.Path { - rel = strings.TrimPrefix(entry.Path, "./.skills/") - } - if entry.IsDir { - name := path.Base(rel) - if name == "." || name == "" { - return "", "" - } - return path.Join(".skills", name, "SKILL.md"), name - } - if path.Base(rel) == "SKILL.md" { - return path.Join(".skills", "SKILL.md"), skillNameFromPath(rel) - } - return "", "" -} - -// parsedSkill holds the result of parsing a SKILL.md file with YAML frontmatter. type parsedSkill struct { Name string Description string - Content string // body after frontmatter - Metadata map[string]any // "metadata" key from frontmatter + Content string + Metadata map[string]any } // parseSkillFile parses a SKILL.md file with YAML frontmatter delimited by "---". -// Format: -// -// --- -// name: your-skill-name -// description: Brief description -// metadata: -// key: value -// --- -// # Body content ... func parseSkillFile(raw string, fallbackName string) parsedSkill { trimmed := strings.TrimSpace(raw) result := parsedSkill{ @@ -318,7 +234,6 @@ func parseSkillFile(raw string, fallbackName string) parsedSkill { return normalizeParsedSkill(result) } - // Find closing "---". rest := trimmed[3:] rest = strings.TrimLeft(rest, " \t") if len(rest) > 0 && rest[0] == '\n' { diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index dc6cae4a..b1223c62 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -1,12 +1,10 @@ package mcp import ( - "bytes" "context" "fmt" "log/slog" - "os" - "path/filepath" + "strings" "sync" "time" @@ -18,6 +16,7 @@ import ( ctr "github.com/memohai/memoh/internal/containerd" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" "github.com/memohai/memoh/internal/identity" + "github.com/memohai/memoh/internal/mcp/mcpclient" ) const ( @@ -25,26 +24,6 @@ const ( ContainerPrefix = "mcp-" ) -type ExecRequest struct { - BotID string - Command []string - Env []string - WorkDir string - Terminal bool - UseStdio bool -} - -type ExecResult struct { - ExitCode uint32 -} - -// ExecWithCaptureResult holds stdout, stderr and exit code from container exec. -type ExecWithCaptureResult struct { - Stdout string - Stderr string - ExitCode uint32 -} - type Manager struct { service ctr.Service cfg config.MCPConfig @@ -55,13 +34,16 @@ type Manager struct { logger *slog.Logger containerLockMu sync.Mutex containerLocks map[string]*sync.Mutex + mu sync.RWMutex + containerIPs map[string]string + grpcPool *mcpclient.Pool } func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, conn *pgxpool.Pool) *Manager { if namespace == "" { namespace = config.DefaultNamespace } - return &Manager{ + m := &Manager{ service: service, cfg: cfg, namespace: namespace, @@ -69,10 +51,13 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, nam queries: dbsqlc.New(conn), logger: log.With(slog.String("component", "mcp")), containerLocks: make(map[string]*sync.Mutex), + containerIPs: make(map[string]string), containerID: func(botID string) string { return ContainerPrefix + botID }, } + m.grpcPool = mcpclient.NewPool(m.ContainerIP) + return m } func (m *Manager) lockContainer(containerID string) func() { @@ -88,6 +73,85 @@ func (m *Manager) lockContainer(containerID string) func() { return lock.Unlock } +// ContainerIP returns the cached IP address for a bot's container. +// If not cached, it attempts to recover the IP by re-running CNI setup. +func (m *Manager) ContainerIP(botID string) string { + m.mu.RLock() + if ip, ok := m.containerIPs[botID]; ok { + m.mu.RUnlock() + return ip + } + m.mu.RUnlock() + + // Cache miss - try to recover IP via CNI setup (idempotent) + ip, err := m.recoverContainerIP(botID) + if err != nil { + m.logger.Warn("container IP recovery failed", slog.String("bot_id", botID), slog.Any("error", err)) + return "" + } + if ip != "" { + m.mu.Lock() + m.containerIPs[botID] = ip + m.mu.Unlock() + m.logger.Info("container IP recovered", slog.String("bot_id", botID), slog.String("ip", ip)) + } + return ip +} + +// SetContainerIP stores the container IP in the cache. +// If the IP changed, the stale gRPC connection is evicted from the pool. +func (m *Manager) SetContainerIP(botID, ip string) { + if ip == "" { + return + } + m.mu.Lock() + old := m.containerIPs[botID] + m.containerIPs[botID] = ip + m.mu.Unlock() + + if old != "" && old != ip { + m.grpcPool.Remove(botID) + m.logger.Info("evicted stale gRPC connection", slog.String("bot_id", botID), slog.String("old_ip", old), slog.String("new_ip", ip)) + } +} + +// recoverContainerIP attempts to restore the container IP by re-running CNI setup. +// CNI plugins are idempotent - calling Setup again returns the existing IP allocation. +func (m *Manager) recoverContainerIP(botID string) (string, error) { + ctx := context.Background() + containerID := m.containerID(botID) + + // First check if container exists and get basic info + info, err := m.service.GetContainer(ctx, containerID) + if err != nil { + return "", err + } + + // Check if IP is stored in labels (if we ever add label persistence) + if ip, ok := info.Labels["mcp.container_ip"]; ok { + return ip, nil + } + + // Container exists but IP not cached - need to re-setup network to get IP + // This happens after server restart when in-memory cache is lost + netResult, err := m.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ + ContainerID: containerID, + CNIBinDir: m.cfg.CNIBinaryDir, + CNIConfDir: m.cfg.CNIConfigDir, + }) + if err != nil { + return "", fmt.Errorf("network setup for IP recovery: %w", err) + } + + return netResult.IP, nil +} + +// MCPClient returns a gRPC client for the given bot's container. +// Implements mcpclient.Provider. +func (m *Manager) MCPClient(ctx context.Context, botID string) (*mcpclient.Client, error) { + return m.grpcPool.Get(ctx, botID) +} + func (m *Manager) Init(ctx context.Context) error { image := m.imageRef() @@ -103,30 +167,19 @@ func (m *Manager) Init(ctx context.Context) error { } // EnsureBot creates the MCP container for a bot if it does not exist. +// Bot data lives in the container's writable layer (snapshot), not bind mounts. func (m *Manager) EnsureBot(ctx context.Context, botID string) error { if err := validateBotID(botID); err != nil { return err } - dataDir, err := m.ensureBotDir(botID) - if err != nil { - return err - } - - dataMount := m.dataMount() image := m.imageRef() - resolvPath, err := ctr.ResolveConfSource(dataDir) + resolvPath, err := ctr.ResolveConfSource(m.dataRoot()) if err != nil { return err } mounts := []ctr.MountSpec{ - { - Destination: dataMount, - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -183,21 +236,31 @@ func (m *Manager) Start(ctx context.Context, botID string) error { return err } - if err := m.service.StartContainer(ctx, m.containerID(botID), &ctr.StartTaskOptions{ - UseStdio: false, - }); err != nil { + if err := m.service.StartContainer(ctx, m.containerID(botID), nil); err != nil { return err } - if err := m.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ + netResult, err := m.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ ContainerID: m.containerID(botID), CNIBinDir: m.cfg.CNIBinaryDir, CNIConfDir: m.cfg.CNIConfigDir, - }); err != nil { + }) + if err != nil { if stopErr := m.service.StopContainer(ctx, m.containerID(botID), &ctr.StopTaskOptions{Force: true}); stopErr != nil { m.logger.Warn("cleanup: stop task failed", slog.String("container_id", m.containerID(botID)), slog.Any("error", stopErr)) } return err } + if netResult.IP != "" { + m.mu.Lock() + m.containerIPs[botID] = netResult.IP + m.mu.Unlock() + m.logger.Info("container network ready", slog.String("bot_id", botID), slog.String("ip", netResult.IP)) + + // Run migration in the background so Start() returns immediately. + // Migration uses its own context so it isn't cancelled when the + // caller's HTTP request finishes. + go m.migrateBindMountData(context.WithoutCancel(ctx), botID) + } return nil } @@ -231,105 +294,6 @@ func (m *Manager) Delete(ctx context.Context, botID string) error { }) } -func (m *Manager) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) { - if err := validateBotID(req.BotID); err != nil { - return nil, err - } - if len(req.Command) == 0 { - return nil, fmt.Errorf("%w: empty command", ctr.ErrInvalidArgument) - } - if m.queries == nil { - return nil, fmt.Errorf("db is not configured") - } - - startedAt := time.Now() - if _, err := m.CreateVersion(ctx, req.BotID); err != nil { - return nil, err - } - - result, err := m.service.ExecTask(ctx, m.containerID(req.BotID), ctr.ExecTaskRequest{ - Args: req.Command, - Env: req.Env, - WorkDir: req.WorkDir, - Terminal: req.Terminal, - UseStdio: req.UseStdio, - }) - if err != nil { - return nil, err - } - - if err := m.insertEvent(ctx, m.containerID(req.BotID), "exec", map[string]any{ - "bot_id": req.BotID, - "command": req.Command, - "work_dir": req.WorkDir, - "exit_code": result.ExitCode, - "duration": time.Since(startedAt).String(), - }); err != nil { - return nil, err - } - - return &ExecResult{ExitCode: result.ExitCode}, nil -} - -// ExecWithCapture runs a command in the bot container and returns stdout, stderr and exit code. -// Use this when the caller needs command output (e.g. MCP exec tool). -// The container must already be running; use Start(botID) or the container/start API to start it. -func (m *Manager) ExecWithCapture(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { - if err := validateBotID(req.BotID); err != nil { - return nil, err - } - if len(req.Command) == 0 { - return nil, fmt.Errorf("%w: empty command", ctr.ErrInvalidArgument) - } - if m.queries == nil { - return nil, fmt.Errorf("db is not configured") - } - return m.execWithCaptureContainerd(ctx, req) -} - -func (m *Manager) execWithCaptureContainerd(ctx context.Context, req ExecRequest) (*ExecWithCaptureResult, error) { - fifoDir, err := os.MkdirTemp(m.dataRoot(), "exec-fifo-") - if err != nil { - return nil, fmt.Errorf("create fifo dir: %w", err) - } - defer os.RemoveAll(fifoDir) - - var stdoutBuf, stderrBuf bytes.Buffer - result, err := m.service.ExecTask(ctx, m.containerID(req.BotID), ctr.ExecTaskRequest{ - Args: req.Command, - Env: req.Env, - WorkDir: req.WorkDir, - Stderr: &stderrBuf, - Stdout: &stdoutBuf, - FIFODir: fifoDir, - }) - if err != nil { - return nil, err - } - return &ExecWithCaptureResult{ - Stdout: stdoutBuf.String(), - Stderr: stderrBuf.String(), - ExitCode: result.ExitCode, - }, nil -} - -// DataDir returns the host data directory for a bot. -func (m *Manager) DataDir(botID string) (string, error) { - if err := validateBotID(botID); err != nil { - return "", err - } - - return filepath.Join(m.dataRoot(), "bots", botID), nil -} - -func (m *Manager) ensureBotDir(botID string) (string, error) { - dir := filepath.Join(m.dataRoot(), "bots", botID) - if err := os.MkdirAll(dir, 0o755); err != nil { - return "", err - } - return dir, nil -} - func (m *Manager) dataRoot() string { if m.cfg.DataRoot == "" { return config.DefaultDataRoot @@ -337,10 +301,6 @@ func (m *Manager) dataRoot() string { return m.cfg.DataRoot } -func (m *Manager) dataMount() string { - return config.DefaultDataMount -} - func (m *Manager) imageRef() string { return m.cfg.ImageRef() } diff --git a/internal/mcp/mcpclient/client.go b/internal/mcp/mcpclient/client.go new file mode 100644 index 00000000..d0a2c113 --- /dev/null +++ b/internal/mcp/mcpclient/client.go @@ -0,0 +1,363 @@ +// Package mcpclient provides a gRPC client for the MCP container service. +// Each bot container runs a gRPC server on port 9090 exposing file and exec +// operations. This client wraps the generated gRPC stubs with connection +// pooling and a simplified API for callers. +package mcpclient + +import ( + "bytes" + "context" + "fmt" + "io" + "sync" + + "github.com/memohai/memoh/internal/config" + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" +) + +// Client wraps a gRPC connection to a single MCP container. +type Client struct { + conn *grpc.ClientConn + svc pb.ContainerServiceClient + target string +} + +// NewClientFromConn wraps an existing gRPC connection into a Client. +// Intended for testing with in-process transports such as bufconn. +func NewClientFromConn(conn *grpc.ClientConn) *Client { + return &Client{ + conn: conn, + svc: pb.NewContainerServiceClient(conn), + target: conn.Target(), + } +} + +// Dial creates a new Client connected to the given container IP. +func Dial(ctx context.Context, ip string) (*Client, error) { + target := fmt.Sprintf("%s:%d", ip, config.MCPGRPCPort) + conn, err := grpc.NewClient(target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("grpc dial %s: %w", target, err) + } + return &Client{ + conn: conn, + svc: pb.NewContainerServiceClient(conn), + target: target, + }, nil +} + +func (c *Client) Close() error { + return c.conn.Close() +} + +func (c *Client) ReadFile(ctx context.Context, path string, lineOffset, nLines int32) (*pb.ReadFileResponse, error) { + resp, err := c.svc.ReadFile(ctx, &pb.ReadFileRequest{ + Path: path, + LineOffset: lineOffset, + NLines: nLines, + }) + return resp, mapError(err) +} + +func (c *Client) WriteFile(ctx context.Context, path string, content []byte) error { + _, err := c.svc.WriteFile(ctx, &pb.WriteFileRequest{ + Path: path, + Content: content, + }) + return mapError(err) +} + +func (c *Client) ListDir(ctx context.Context, path string, recursive bool) ([]*pb.FileEntry, error) { + resp, err := c.svc.ListDir(ctx, &pb.ListDirRequest{ + Path: path, + Recursive: recursive, + }) + if err != nil { + return nil, mapError(err) + } + return resp.GetEntries(), nil +} + +func (c *Client) Stat(ctx context.Context, path string) (*pb.FileEntry, error) { + resp, err := c.svc.Stat(ctx, &pb.StatRequest{Path: path}) + if err != nil { + return nil, mapError(err) + } + return resp.GetEntry(), nil +} + +func (c *Client) Mkdir(ctx context.Context, path string) error { + _, err := c.svc.Mkdir(ctx, &pb.MkdirRequest{Path: path}) + return mapError(err) +} + +func (c *Client) Rename(ctx context.Context, oldPath, newPath string) error { + _, err := c.svc.Rename(ctx, &pb.RenameRequest{OldPath: oldPath, NewPath: newPath}) + return mapError(err) +} + +// ExecResult holds the output of a non-streaming exec call. +type ExecResult struct { + Stdout string + Stderr string + ExitCode int32 +} + +// Exec runs a command and collects all output. For streaming, use ExecStream. +func (c *Client) Exec(ctx context.Context, command, workDir string, timeout int32) (*ExecResult, error) { + return c.ExecWithStdin(ctx, command, workDir, timeout, nil) +} + +// ExecWithStdin runs a command with optional stdin data. +func (c *Client) ExecWithStdin(ctx context.Context, command, workDir string, timeout int32, stdinData []byte) (*ExecResult, error) { + stream, err := c.svc.Exec(ctx) + if err != nil { + return nil, mapError(err) + } + + // Send config message first + err = stream.Send(&pb.ExecInput{ + Command: command, + WorkDir: workDir, + TimeoutSeconds: timeout, + StdinData: stdinData, + }) + if err != nil { + return nil, err + } + + var stdout, stderr bytes.Buffer + var exitCode int32 + + for { + msg, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + switch msg.GetStream() { + case pb.ExecOutput_STDOUT: + stdout.Write(msg.GetData()) + case pb.ExecOutput_STDERR: + stderr.Write(msg.GetData()) + case pb.ExecOutput_EXIT: + exitCode = msg.GetExitCode() + } + } + + return &ExecResult{ + Stdout: stdout.String(), + Stderr: stderr.String(), + ExitCode: exitCode, + }, nil +} + +// ExecStream returns a bidirectional stream for interactive exec. +// Caller can send stdin data and receive stdout/stderr in real-time. +func (c *Client) ExecStream(ctx context.Context, command, workDir string, timeout int32) (*ExecStream, error) { + stream, err := c.svc.Exec(ctx) + if err != nil { + return nil, mapError(err) + } + + // Send config message first + err = stream.Send(&pb.ExecInput{ + Command: command, + WorkDir: workDir, + TimeoutSeconds: timeout, + }) + if err != nil { + return nil, err + } + + return &ExecStream{stream: stream}, nil +} + +// ExecStream wraps a bidirectional exec stream. +type ExecStream struct { + stream pb.ContainerService_ExecClient +} + +// SendStdin sends data to the process stdin. +func (s *ExecStream) SendStdin(data []byte) error { + return s.stream.Send(&pb.ExecInput{ + StdinData: data, + }) +} + +// Recv receives output from the process. +func (s *ExecStream) Recv() (*pb.ExecOutput, error) { + return s.stream.Recv() +} + +// Close closes the stream. +func (s *ExecStream) Close() error { + return s.stream.CloseSend() +} + +// ReadRaw streams raw file bytes. Caller must consume the returned reader. +func (c *Client) ReadRaw(ctx context.Context, path string) (io.ReadCloser, error) { + stream, err := c.svc.ReadRaw(ctx, &pb.ReadRawRequest{Path: path}) + if err != nil { + return nil, mapError(err) + } + return &streamReader{stream: stream}, nil +} + +// WriteRaw writes raw bytes to a file in the container. +func (c *Client) WriteRaw(ctx context.Context, path string, r io.Reader) (int64, error) { + stream, err := c.svc.WriteRaw(ctx) + if err != nil { + return 0, mapError(err) + } + + buf := make([]byte, 64*1024) + first := true + for { + n, readErr := r.Read(buf) + if n > 0 { + chunk := &pb.WriteRawChunk{Data: buf[:n]} + if first { + chunk.Path = path + first = false + } + if sendErr := stream.Send(chunk); sendErr != nil { + return 0, sendErr + } + } + if readErr == io.EOF { + break + } + if readErr != nil { + return 0, readErr + } + } + + resp, err := stream.CloseAndRecv() + if err != nil { + return 0, err + } + return resp.GetBytesWritten(), nil +} + +func (c *Client) DeleteFile(ctx context.Context, path string, recursive bool) error { + _, err := c.svc.DeleteFile(ctx, &pb.DeleteFileRequest{ + Path: path, + Recursive: recursive, + }) + return mapError(err) +} + +// streamReader adapts a gRPC server stream into an io.ReadCloser. +type streamReader struct { + stream pb.ContainerService_ReadRawClient + buf []byte + off int +} + +func (r *streamReader) Read(p []byte) (int, error) { + for r.off >= len(r.buf) { + msg, err := r.stream.Recv() + if err != nil { + return 0, err + } + r.buf = msg.GetData() + r.off = 0 + } + n := copy(p, r.buf[r.off:]) + r.off += n + return n, nil +} + +func (r *streamReader) Close() error { + return nil +} + +// Provider resolves a gRPC client for a given bot container. +type Provider interface { + MCPClient(ctx context.Context, botID string) (*Client, error) +} + +// Pool manages cached gRPC clients keyed by bot ID. +type Pool struct { + mu sync.RWMutex + clients map[string]*Client + ipFunc func(botID string) string +} + +// NewPool creates a client pool. ipFunc maps bot ID to container IP. +func NewPool(ipFunc func(string) string) *Pool { + return &Pool{ + clients: make(map[string]*Client), + ipFunc: ipFunc, + } +} + +// MCPClient implements Provider. Alias for Get. +func (p *Pool) MCPClient(ctx context.Context, botID string) (*Client, error) { + return p.Get(ctx, botID) +} + +// Get returns a cached client or dials a new one. +// Stale connections (Shutdown / TransientFailure) are evicted automatically. +func (p *Pool) Get(ctx context.Context, botID string) (*Client, error) { + p.mu.RLock() + if c, ok := p.clients[botID]; ok { + state := c.conn.GetState() + if state != connectivity.Shutdown && state != connectivity.TransientFailure { + p.mu.RUnlock() + return c, nil + } + p.mu.RUnlock() + p.Remove(botID) + } else { + p.mu.RUnlock() + } + + ip := p.ipFunc(botID) + if ip == "" { + return nil, fmt.Errorf("no IP for bot %s", botID) + } + + c, err := Dial(ctx, ip) + if err != nil { + return nil, err + } + + p.mu.Lock() + if existing, ok := p.clients[botID]; ok { + p.mu.Unlock() + c.Close() + return existing, nil + } + p.clients[botID] = c + p.mu.Unlock() + return c, nil +} + +// Remove closes and removes the client for a bot. +func (p *Pool) Remove(botID string) { + p.mu.Lock() + if c, ok := p.clients[botID]; ok { + c.Close() + delete(p.clients, botID) + } + p.mu.Unlock() +} + +// CloseAll closes all cached clients. +func (p *Pool) CloseAll() { + p.mu.Lock() + for id, c := range p.clients { + c.Close() + delete(p.clients, id) + } + p.mu.Unlock() +} diff --git a/internal/mcp/mcpclient/errors.go b/internal/mcp/mcpclient/errors.go new file mode 100644 index 00000000..f7ba2d42 --- /dev/null +++ b/internal/mcp/mcpclient/errors.go @@ -0,0 +1,41 @@ +package mcpclient + +import ( + "errors" + "fmt" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var ( + ErrNotFound = errors.New("not found") + ErrUnavailable = errors.New("unavailable") + ErrBadRequest = errors.New("invalid argument") + ErrForbidden = errors.New("permission denied") +) + +// mapError converts a gRPC status error into a domain error. +// Non-gRPC errors pass through unchanged. +func mapError(err error) error { + if err == nil { + return nil + } + s, ok := status.FromError(err) + if !ok { + return err + } + msg := s.Message() + switch s.Code() { + case codes.NotFound: + return fmt.Errorf("%w: %s", ErrNotFound, msg) + case codes.InvalidArgument: + return fmt.Errorf("%w: %s", ErrBadRequest, msg) + case codes.PermissionDenied: + return fmt.Errorf("%w: %s", ErrForbidden, msg) + case codes.Unavailable, codes.Aborted: + return fmt.Errorf("%w: %s", ErrUnavailable, msg) + default: + return fmt.Errorf("grpc %s: %s", s.Code(), msg) + } +} diff --git a/internal/mcp/mcpcontainer/mcpcontainer.pb.go b/internal/mcp/mcpcontainer/mcpcontainer.pb.go new file mode 100644 index 00000000..f1808123 --- /dev/null +++ b/internal/mcp/mcpcontainer/mcpcontainer.pb.go @@ -0,0 +1,1294 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc v4.25.3 +// source: internal/mcp/mcpcontainer/mcpcontainer.proto + +package mcpcontainer + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ExecOutput_Stream int32 + +const ( + ExecOutput_STDOUT ExecOutput_Stream = 0 + ExecOutput_STDERR ExecOutput_Stream = 1 + ExecOutput_EXIT ExecOutput_Stream = 2 +) + +// Enum value maps for ExecOutput_Stream. +var ( + ExecOutput_Stream_name = map[int32]string{ + 0: "STDOUT", + 1: "STDERR", + 2: "EXIT", + } + ExecOutput_Stream_value = map[string]int32{ + "STDOUT": 0, + "STDERR": 1, + "EXIT": 2, + } +) + +func (x ExecOutput_Stream) Enum() *ExecOutput_Stream { + p := new(ExecOutput_Stream) + *p = x + return p +} + +func (x ExecOutput_Stream) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ExecOutput_Stream) Descriptor() protoreflect.EnumDescriptor { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_enumTypes[0].Descriptor() +} + +func (ExecOutput_Stream) Type() protoreflect.EnumType { + return &file_internal_mcp_mcpcontainer_mcpcontainer_proto_enumTypes[0] +} + +func (x ExecOutput_Stream) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use ExecOutput_Stream.Descriptor instead. +func (ExecOutput_Stream) EnumDescriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{8, 0} +} + +type ReadFileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + LineOffset int32 `protobuf:"varint,2,opt,name=line_offset,json=lineOffset,proto3" json:"line_offset,omitempty"` + NLines int32 `protobuf:"varint,3,opt,name=n_lines,json=nLines,proto3" json:"n_lines,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReadFileRequest) Reset() { + *x = ReadFileRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReadFileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReadFileRequest) ProtoMessage() {} + +func (x *ReadFileRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReadFileRequest.ProtoReflect.Descriptor instead. +func (*ReadFileRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{0} +} + +func (x *ReadFileRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *ReadFileRequest) GetLineOffset() int32 { + if x != nil { + return x.LineOffset + } + return 0 +} + +func (x *ReadFileRequest) GetNLines() int32 { + if x != nil { + return x.NLines + } + return 0 +} + +type ReadFileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Content string `protobuf:"bytes,1,opt,name=content,proto3" json:"content,omitempty"` + TotalLines int32 `protobuf:"varint,2,opt,name=total_lines,json=totalLines,proto3" json:"total_lines,omitempty"` + Binary bool `protobuf:"varint,3,opt,name=binary,proto3" json:"binary,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReadFileResponse) Reset() { + *x = ReadFileResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReadFileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReadFileResponse) ProtoMessage() {} + +func (x *ReadFileResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReadFileResponse.ProtoReflect.Descriptor instead. +func (*ReadFileResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{1} +} + +func (x *ReadFileResponse) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *ReadFileResponse) GetTotalLines() int32 { + if x != nil { + return x.TotalLines + } + return 0 +} + +func (x *ReadFileResponse) GetBinary() bool { + if x != nil { + return x.Binary + } + return false +} + +type WriteFileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Content []byte `protobuf:"bytes,2,opt,name=content,proto3" json:"content,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WriteFileRequest) Reset() { + *x = WriteFileRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WriteFileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WriteFileRequest) ProtoMessage() {} + +func (x *WriteFileRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WriteFileRequest.ProtoReflect.Descriptor instead. +func (*WriteFileRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{2} +} + +func (x *WriteFileRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *WriteFileRequest) GetContent() []byte { + if x != nil { + return x.Content + } + return nil +} + +type WriteFileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WriteFileResponse) Reset() { + *x = WriteFileResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WriteFileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WriteFileResponse) ProtoMessage() {} + +func (x *WriteFileResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WriteFileResponse.ProtoReflect.Descriptor instead. +func (*WriteFileResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{3} +} + +type ListDirRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Recursive bool `protobuf:"varint,2,opt,name=recursive,proto3" json:"recursive,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListDirRequest) Reset() { + *x = ListDirRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListDirRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListDirRequest) ProtoMessage() {} + +func (x *ListDirRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListDirRequest.ProtoReflect.Descriptor instead. +func (*ListDirRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{4} +} + +func (x *ListDirRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *ListDirRequest) GetRecursive() bool { + if x != nil { + return x.Recursive + } + return false +} + +type FileEntry struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + IsDir bool `protobuf:"varint,2,opt,name=is_dir,json=isDir,proto3" json:"is_dir,omitempty"` + Size int64 `protobuf:"varint,3,opt,name=size,proto3" json:"size,omitempty"` + Mode string `protobuf:"bytes,4,opt,name=mode,proto3" json:"mode,omitempty"` + ModTime string `protobuf:"bytes,5,opt,name=mod_time,json=modTime,proto3" json:"mod_time,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FileEntry) Reset() { + *x = FileEntry{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FileEntry) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FileEntry) ProtoMessage() {} + +func (x *FileEntry) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FileEntry.ProtoReflect.Descriptor instead. +func (*FileEntry) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{5} +} + +func (x *FileEntry) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *FileEntry) GetIsDir() bool { + if x != nil { + return x.IsDir + } + return false +} + +func (x *FileEntry) GetSize() int64 { + if x != nil { + return x.Size + } + return 0 +} + +func (x *FileEntry) GetMode() string { + if x != nil { + return x.Mode + } + return "" +} + +func (x *FileEntry) GetModTime() string { + if x != nil { + return x.ModTime + } + return "" +} + +type ListDirResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Entries []*FileEntry `protobuf:"bytes,1,rep,name=entries,proto3" json:"entries,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListDirResponse) Reset() { + *x = ListDirResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListDirResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListDirResponse) ProtoMessage() {} + +func (x *ListDirResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListDirResponse.ProtoReflect.Descriptor instead. +func (*ListDirResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{6} +} + +func (x *ListDirResponse) GetEntries() []*FileEntry { + if x != nil { + return x.Entries + } + return nil +} + +type ExecInput struct { + state protoimpl.MessageState `protogen:"open.v1"` + Command string `protobuf:"bytes,1,opt,name=command,proto3" json:"command,omitempty"` + WorkDir string `protobuf:"bytes,2,opt,name=work_dir,json=workDir,proto3" json:"work_dir,omitempty"` + Env []string `protobuf:"bytes,3,rep,name=env,proto3" json:"env,omitempty"` + TimeoutSeconds int32 `protobuf:"varint,4,opt,name=timeout_seconds,json=timeoutSeconds,proto3" json:"timeout_seconds,omitempty"` + StdinData []byte `protobuf:"bytes,5,opt,name=stdin_data,json=stdinData,proto3" json:"stdin_data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecInput) Reset() { + *x = ExecInput{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecInput) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecInput) ProtoMessage() {} + +func (x *ExecInput) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecInput.ProtoReflect.Descriptor instead. +func (*ExecInput) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{7} +} + +func (x *ExecInput) GetCommand() string { + if x != nil { + return x.Command + } + return "" +} + +func (x *ExecInput) GetWorkDir() string { + if x != nil { + return x.WorkDir + } + return "" +} + +func (x *ExecInput) GetEnv() []string { + if x != nil { + return x.Env + } + return nil +} + +func (x *ExecInput) GetTimeoutSeconds() int32 { + if x != nil { + return x.TimeoutSeconds + } + return 0 +} + +func (x *ExecInput) GetStdinData() []byte { + if x != nil { + return x.StdinData + } + return nil +} + +type ExecOutput struct { + state protoimpl.MessageState `protogen:"open.v1"` + Stream ExecOutput_Stream `protobuf:"varint,1,opt,name=stream,proto3,enum=mcpcontainer.ExecOutput_Stream" json:"stream,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + ExitCode int32 `protobuf:"varint,3,opt,name=exit_code,json=exitCode,proto3" json:"exit_code,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecOutput) Reset() { + *x = ExecOutput{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecOutput) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecOutput) ProtoMessage() {} + +func (x *ExecOutput) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecOutput.ProtoReflect.Descriptor instead. +func (*ExecOutput) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{8} +} + +func (x *ExecOutput) GetStream() ExecOutput_Stream { + if x != nil { + return x.Stream + } + return ExecOutput_STDOUT +} + +func (x *ExecOutput) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +func (x *ExecOutput) GetExitCode() int32 { + if x != nil { + return x.ExitCode + } + return 0 +} + +type ReadRawRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReadRawRequest) Reset() { + *x = ReadRawRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReadRawRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReadRawRequest) ProtoMessage() {} + +func (x *ReadRawRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReadRawRequest.ProtoReflect.Descriptor instead. +func (*ReadRawRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{9} +} + +func (x *ReadRawRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +type DataChunk struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DataChunk) Reset() { + *x = DataChunk{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DataChunk) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DataChunk) ProtoMessage() {} + +func (x *DataChunk) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DataChunk.ProtoReflect.Descriptor instead. +func (*DataChunk) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{10} +} + +func (x *DataChunk) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type WriteRawChunk struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WriteRawChunk) Reset() { + *x = WriteRawChunk{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WriteRawChunk) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WriteRawChunk) ProtoMessage() {} + +func (x *WriteRawChunk) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WriteRawChunk.ProtoReflect.Descriptor instead. +func (*WriteRawChunk) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{11} +} + +func (x *WriteRawChunk) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *WriteRawChunk) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type WriteRawResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + BytesWritten int64 `protobuf:"varint,1,opt,name=bytes_written,json=bytesWritten,proto3" json:"bytes_written,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WriteRawResponse) Reset() { + *x = WriteRawResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WriteRawResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WriteRawResponse) ProtoMessage() {} + +func (x *WriteRawResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WriteRawResponse.ProtoReflect.Descriptor instead. +func (*WriteRawResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{12} +} + +func (x *WriteRawResponse) GetBytesWritten() int64 { + if x != nil { + return x.BytesWritten + } + return 0 +} + +type DeleteFileRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Recursive bool `protobuf:"varint,2,opt,name=recursive,proto3" json:"recursive,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteFileRequest) Reset() { + *x = DeleteFileRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteFileRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteFileRequest) ProtoMessage() {} + +func (x *DeleteFileRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteFileRequest.ProtoReflect.Descriptor instead. +func (*DeleteFileRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{13} +} + +func (x *DeleteFileRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +func (x *DeleteFileRequest) GetRecursive() bool { + if x != nil { + return x.Recursive + } + return false +} + +type DeleteFileResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteFileResponse) Reset() { + *x = DeleteFileResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteFileResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteFileResponse) ProtoMessage() {} + +func (x *DeleteFileResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteFileResponse.ProtoReflect.Descriptor instead. +func (*DeleteFileResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{14} +} + +type StatRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StatRequest) Reset() { + *x = StatRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StatRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StatRequest) ProtoMessage() {} + +func (x *StatRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StatRequest.ProtoReflect.Descriptor instead. +func (*StatRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{15} +} + +func (x *StatRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +type StatResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Entry *FileEntry `protobuf:"bytes,1,opt,name=entry,proto3" json:"entry,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StatResponse) Reset() { + *x = StatResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StatResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StatResponse) ProtoMessage() {} + +func (x *StatResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[16] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StatResponse.ProtoReflect.Descriptor instead. +func (*StatResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{16} +} + +func (x *StatResponse) GetEntry() *FileEntry { + if x != nil { + return x.Entry + } + return nil +} + +type MkdirRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MkdirRequest) Reset() { + *x = MkdirRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MkdirRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MkdirRequest) ProtoMessage() {} + +func (x *MkdirRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[17] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MkdirRequest.ProtoReflect.Descriptor instead. +func (*MkdirRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{17} +} + +func (x *MkdirRequest) GetPath() string { + if x != nil { + return x.Path + } + return "" +} + +type MkdirResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MkdirResponse) Reset() { + *x = MkdirResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MkdirResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MkdirResponse) ProtoMessage() {} + +func (x *MkdirResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[18] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MkdirResponse.ProtoReflect.Descriptor instead. +func (*MkdirResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{18} +} + +type RenameRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + OldPath string `protobuf:"bytes,1,opt,name=old_path,json=oldPath,proto3" json:"old_path,omitempty"` + NewPath string `protobuf:"bytes,2,opt,name=new_path,json=newPath,proto3" json:"new_path,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RenameRequest) Reset() { + *x = RenameRequest{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RenameRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RenameRequest) ProtoMessage() {} + +func (x *RenameRequest) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[19] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RenameRequest.ProtoReflect.Descriptor instead. +func (*RenameRequest) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{19} +} + +func (x *RenameRequest) GetOldPath() string { + if x != nil { + return x.OldPath + } + return "" +} + +func (x *RenameRequest) GetNewPath() string { + if x != nil { + return x.NewPath + } + return "" +} + +type RenameResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RenameResponse) Reset() { + *x = RenameResponse{} + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RenameResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RenameResponse) ProtoMessage() {} + +func (x *RenameResponse) ProtoReflect() protoreflect.Message { + mi := &file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes[20] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RenameResponse.ProtoReflect.Descriptor instead. +func (*RenameResponse) Descriptor() ([]byte, []int) { + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP(), []int{20} +} + +var File_internal_mcp_mcpcontainer_mcpcontainer_proto protoreflect.FileDescriptor + +const file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDesc = "" + + "\n" + + ",internal/mcp/mcpcontainer/mcpcontainer.proto\x12\fmcpcontainer\"_\n" + + "\x0fReadFileRequest\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\x12\x1f\n" + + "\vline_offset\x18\x02 \x01(\x05R\n" + + "lineOffset\x12\x17\n" + + "\an_lines\x18\x03 \x01(\x05R\x06nLines\"e\n" + + "\x10ReadFileResponse\x12\x18\n" + + "\acontent\x18\x01 \x01(\tR\acontent\x12\x1f\n" + + "\vtotal_lines\x18\x02 \x01(\x05R\n" + + "totalLines\x12\x16\n" + + "\x06binary\x18\x03 \x01(\bR\x06binary\"@\n" + + "\x10WriteFileRequest\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\x12\x18\n" + + "\acontent\x18\x02 \x01(\fR\acontent\"\x13\n" + + "\x11WriteFileResponse\"B\n" + + "\x0eListDirRequest\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\x12\x1c\n" + + "\trecursive\x18\x02 \x01(\bR\trecursive\"y\n" + + "\tFileEntry\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\x12\x15\n" + + "\x06is_dir\x18\x02 \x01(\bR\x05isDir\x12\x12\n" + + "\x04size\x18\x03 \x01(\x03R\x04size\x12\x12\n" + + "\x04mode\x18\x04 \x01(\tR\x04mode\x12\x19\n" + + "\bmod_time\x18\x05 \x01(\tR\amodTime\"D\n" + + "\x0fListDirResponse\x121\n" + + "\aentries\x18\x01 \x03(\v2\x17.mcpcontainer.FileEntryR\aentries\"\x9a\x01\n" + + "\tExecInput\x12\x18\n" + + "\acommand\x18\x01 \x01(\tR\acommand\x12\x19\n" + + "\bwork_dir\x18\x02 \x01(\tR\aworkDir\x12\x10\n" + + "\x03env\x18\x03 \x03(\tR\x03env\x12'\n" + + "\x0ftimeout_seconds\x18\x04 \x01(\x05R\x0etimeoutSeconds\x12\x1d\n" + + "\n" + + "stdin_data\x18\x05 \x01(\fR\tstdinData\"\xa2\x01\n" + + "\n" + + "ExecOutput\x127\n" + + "\x06stream\x18\x01 \x01(\x0e2\x1f.mcpcontainer.ExecOutput.StreamR\x06stream\x12\x12\n" + + "\x04data\x18\x02 \x01(\fR\x04data\x12\x1b\n" + + "\texit_code\x18\x03 \x01(\x05R\bexitCode\"*\n" + + "\x06Stream\x12\n" + + "\n" + + "\x06STDOUT\x10\x00\x12\n" + + "\n" + + "\x06STDERR\x10\x01\x12\b\n" + + "\x04EXIT\x10\x02\"$\n" + + "\x0eReadRawRequest\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\"\x1f\n" + + "\tDataChunk\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\"7\n" + + "\rWriteRawChunk\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\x12\x12\n" + + "\x04data\x18\x02 \x01(\fR\x04data\"7\n" + + "\x10WriteRawResponse\x12#\n" + + "\rbytes_written\x18\x01 \x01(\x03R\fbytesWritten\"E\n" + + "\x11DeleteFileRequest\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\x12\x1c\n" + + "\trecursive\x18\x02 \x01(\bR\trecursive\"\x14\n" + + "\x12DeleteFileResponse\"!\n" + + "\vStatRequest\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\"=\n" + + "\fStatResponse\x12-\n" + + "\x05entry\x18\x01 \x01(\v2\x17.mcpcontainer.FileEntryR\x05entry\"\"\n" + + "\fMkdirRequest\x12\x12\n" + + "\x04path\x18\x01 \x01(\tR\x04path\"\x0f\n" + + "\rMkdirResponse\"E\n" + + "\rRenameRequest\x12\x19\n" + + "\bold_path\x18\x01 \x01(\tR\aoldPath\x12\x19\n" + + "\bnew_path\x18\x02 \x01(\tR\anewPath\"\x10\n" + + "\x0eRenameResponse2\xd8\x05\n" + + "\x10ContainerService\x12I\n" + + "\bReadFile\x12\x1d.mcpcontainer.ReadFileRequest\x1a\x1e.mcpcontainer.ReadFileResponse\x12L\n" + + "\tWriteFile\x12\x1e.mcpcontainer.WriteFileRequest\x1a\x1f.mcpcontainer.WriteFileResponse\x12F\n" + + "\aListDir\x12\x1c.mcpcontainer.ListDirRequest\x1a\x1d.mcpcontainer.ListDirResponse\x12=\n" + + "\x04Stat\x12\x19.mcpcontainer.StatRequest\x1a\x1a.mcpcontainer.StatResponse\x12@\n" + + "\x05Mkdir\x12\x1a.mcpcontainer.MkdirRequest\x1a\x1b.mcpcontainer.MkdirResponse\x12C\n" + + "\x06Rename\x12\x1b.mcpcontainer.RenameRequest\x1a\x1c.mcpcontainer.RenameResponse\x12=\n" + + "\x04Exec\x12\x17.mcpcontainer.ExecInput\x1a\x18.mcpcontainer.ExecOutput(\x010\x01\x12B\n" + + "\aReadRaw\x12\x1c.mcpcontainer.ReadRawRequest\x1a\x17.mcpcontainer.DataChunk0\x01\x12I\n" + + "\bWriteRaw\x12\x1b.mcpcontainer.WriteRawChunk\x1a\x1e.mcpcontainer.WriteRawResponse(\x01\x12O\n" + + "\n" + + "DeleteFile\x12\x1f.mcpcontainer.DeleteFileRequest\x1a .mcpcontainer.DeleteFileResponseB4Z2github.com/memohai/memoh/internal/mcp/mcpcontainerb\x06proto3" + +var ( + file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescOnce sync.Once + file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescData []byte +) + +func file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescGZIP() []byte { + file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescOnce.Do(func() { + file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDesc), len(file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDesc))) + }) + return file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDescData +} + +var file_internal_mcp_mcpcontainer_mcpcontainer_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes = make([]protoimpl.MessageInfo, 21) +var file_internal_mcp_mcpcontainer_mcpcontainer_proto_goTypes = []any{ + (ExecOutput_Stream)(0), // 0: mcpcontainer.ExecOutput.Stream + (*ReadFileRequest)(nil), // 1: mcpcontainer.ReadFileRequest + (*ReadFileResponse)(nil), // 2: mcpcontainer.ReadFileResponse + (*WriteFileRequest)(nil), // 3: mcpcontainer.WriteFileRequest + (*WriteFileResponse)(nil), // 4: mcpcontainer.WriteFileResponse + (*ListDirRequest)(nil), // 5: mcpcontainer.ListDirRequest + (*FileEntry)(nil), // 6: mcpcontainer.FileEntry + (*ListDirResponse)(nil), // 7: mcpcontainer.ListDirResponse + (*ExecInput)(nil), // 8: mcpcontainer.ExecInput + (*ExecOutput)(nil), // 9: mcpcontainer.ExecOutput + (*ReadRawRequest)(nil), // 10: mcpcontainer.ReadRawRequest + (*DataChunk)(nil), // 11: mcpcontainer.DataChunk + (*WriteRawChunk)(nil), // 12: mcpcontainer.WriteRawChunk + (*WriteRawResponse)(nil), // 13: mcpcontainer.WriteRawResponse + (*DeleteFileRequest)(nil), // 14: mcpcontainer.DeleteFileRequest + (*DeleteFileResponse)(nil), // 15: mcpcontainer.DeleteFileResponse + (*StatRequest)(nil), // 16: mcpcontainer.StatRequest + (*StatResponse)(nil), // 17: mcpcontainer.StatResponse + (*MkdirRequest)(nil), // 18: mcpcontainer.MkdirRequest + (*MkdirResponse)(nil), // 19: mcpcontainer.MkdirResponse + (*RenameRequest)(nil), // 20: mcpcontainer.RenameRequest + (*RenameResponse)(nil), // 21: mcpcontainer.RenameResponse +} +var file_internal_mcp_mcpcontainer_mcpcontainer_proto_depIdxs = []int32{ + 6, // 0: mcpcontainer.ListDirResponse.entries:type_name -> mcpcontainer.FileEntry + 0, // 1: mcpcontainer.ExecOutput.stream:type_name -> mcpcontainer.ExecOutput.Stream + 6, // 2: mcpcontainer.StatResponse.entry:type_name -> mcpcontainer.FileEntry + 1, // 3: mcpcontainer.ContainerService.ReadFile:input_type -> mcpcontainer.ReadFileRequest + 3, // 4: mcpcontainer.ContainerService.WriteFile:input_type -> mcpcontainer.WriteFileRequest + 5, // 5: mcpcontainer.ContainerService.ListDir:input_type -> mcpcontainer.ListDirRequest + 16, // 6: mcpcontainer.ContainerService.Stat:input_type -> mcpcontainer.StatRequest + 18, // 7: mcpcontainer.ContainerService.Mkdir:input_type -> mcpcontainer.MkdirRequest + 20, // 8: mcpcontainer.ContainerService.Rename:input_type -> mcpcontainer.RenameRequest + 8, // 9: mcpcontainer.ContainerService.Exec:input_type -> mcpcontainer.ExecInput + 10, // 10: mcpcontainer.ContainerService.ReadRaw:input_type -> mcpcontainer.ReadRawRequest + 12, // 11: mcpcontainer.ContainerService.WriteRaw:input_type -> mcpcontainer.WriteRawChunk + 14, // 12: mcpcontainer.ContainerService.DeleteFile:input_type -> mcpcontainer.DeleteFileRequest + 2, // 13: mcpcontainer.ContainerService.ReadFile:output_type -> mcpcontainer.ReadFileResponse + 4, // 14: mcpcontainer.ContainerService.WriteFile:output_type -> mcpcontainer.WriteFileResponse + 7, // 15: mcpcontainer.ContainerService.ListDir:output_type -> mcpcontainer.ListDirResponse + 17, // 16: mcpcontainer.ContainerService.Stat:output_type -> mcpcontainer.StatResponse + 19, // 17: mcpcontainer.ContainerService.Mkdir:output_type -> mcpcontainer.MkdirResponse + 21, // 18: mcpcontainer.ContainerService.Rename:output_type -> mcpcontainer.RenameResponse + 9, // 19: mcpcontainer.ContainerService.Exec:output_type -> mcpcontainer.ExecOutput + 11, // 20: mcpcontainer.ContainerService.ReadRaw:output_type -> mcpcontainer.DataChunk + 13, // 21: mcpcontainer.ContainerService.WriteRaw:output_type -> mcpcontainer.WriteRawResponse + 15, // 22: mcpcontainer.ContainerService.DeleteFile:output_type -> mcpcontainer.DeleteFileResponse + 13, // [13:23] is the sub-list for method output_type + 3, // [3:13] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_internal_mcp_mcpcontainer_mcpcontainer_proto_init() } +func file_internal_mcp_mcpcontainer_mcpcontainer_proto_init() { + if File_internal_mcp_mcpcontainer_mcpcontainer_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDesc), len(file_internal_mcp_mcpcontainer_mcpcontainer_proto_rawDesc)), + NumEnums: 1, + NumMessages: 21, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_internal_mcp_mcpcontainer_mcpcontainer_proto_goTypes, + DependencyIndexes: file_internal_mcp_mcpcontainer_mcpcontainer_proto_depIdxs, + EnumInfos: file_internal_mcp_mcpcontainer_mcpcontainer_proto_enumTypes, + MessageInfos: file_internal_mcp_mcpcontainer_mcpcontainer_proto_msgTypes, + }.Build() + File_internal_mcp_mcpcontainer_mcpcontainer_proto = out.File + file_internal_mcp_mcpcontainer_mcpcontainer_proto_goTypes = nil + file_internal_mcp_mcpcontainer_mcpcontainer_proto_depIdxs = nil +} diff --git a/internal/mcp/mcpcontainer/mcpcontainer.proto b/internal/mcp/mcpcontainer/mcpcontainer.proto new file mode 100644 index 00000000..9e81a7a6 --- /dev/null +++ b/internal/mcp/mcpcontainer/mcpcontainer.proto @@ -0,0 +1,118 @@ +syntax = "proto3"; + +package mcpcontainer; + +option go_package = "github.com/memohai/memoh/internal/mcp/mcpcontainer"; + +service ContainerService { + rpc ReadFile(ReadFileRequest) returns (ReadFileResponse); + rpc WriteFile(WriteFileRequest) returns (WriteFileResponse); + rpc ListDir(ListDirRequest) returns (ListDirResponse); + rpc Stat(StatRequest) returns (StatResponse); + rpc Mkdir(MkdirRequest) returns (MkdirResponse); + rpc Rename(RenameRequest) returns (RenameResponse); + rpc Exec(stream ExecInput) returns (stream ExecOutput); + rpc ReadRaw(ReadRawRequest) returns (stream DataChunk); + rpc WriteRaw(stream WriteRawChunk) returns (WriteRawResponse); + rpc DeleteFile(DeleteFileRequest) returns (DeleteFileResponse); +} + +message ReadFileRequest { + string path = 1; + int32 line_offset = 2; + int32 n_lines = 3; +} + +message ReadFileResponse { + string content = 1; + int32 total_lines = 2; + bool binary = 3; +} + +message WriteFileRequest { + string path = 1; + bytes content = 2; +} + +message WriteFileResponse {} + +message ListDirRequest { + string path = 1; + bool recursive = 2; +} + +message FileEntry { + string path = 1; + bool is_dir = 2; + int64 size = 3; + string mode = 4; + string mod_time = 5; +} + +message ListDirResponse { + repeated FileEntry entries = 1; +} + +message ExecInput { + string command = 1; + string work_dir = 2; + repeated string env = 3; + int32 timeout_seconds = 4; + bytes stdin_data = 5; +} + +message ExecOutput { + enum Stream { + STDOUT = 0; + STDERR = 1; + EXIT = 2; + } + Stream stream = 1; + bytes data = 2; + int32 exit_code = 3; +} + +message ReadRawRequest { + string path = 1; +} + +message DataChunk { + bytes data = 1; +} + +message WriteRawChunk { + string path = 1; + bytes data = 2; +} + +message WriteRawResponse { + int64 bytes_written = 1; +} + +message DeleteFileRequest { + string path = 1; + bool recursive = 2; +} + +message DeleteFileResponse {} + +message StatRequest { + string path = 1; +} + +message StatResponse { + FileEntry entry = 1; +} + +message MkdirRequest { + string path = 1; +} + +message MkdirResponse {} + +message RenameRequest { + string old_path = 1; + string new_path = 2; +} + +message RenameResponse {} diff --git a/internal/mcp/mcpcontainer/mcpcontainer_grpc.pb.go b/internal/mcp/mcpcontainer/mcpcontainer_grpc.pb.go new file mode 100644 index 00000000..39d23f7e --- /dev/null +++ b/internal/mcp/mcpcontainer/mcpcontainer_grpc.pb.go @@ -0,0 +1,454 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v4.25.3 +// source: internal/mcp/mcpcontainer/mcpcontainer.proto + +package mcpcontainer + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ContainerService_ReadFile_FullMethodName = "/mcpcontainer.ContainerService/ReadFile" + ContainerService_WriteFile_FullMethodName = "/mcpcontainer.ContainerService/WriteFile" + ContainerService_ListDir_FullMethodName = "/mcpcontainer.ContainerService/ListDir" + ContainerService_Stat_FullMethodName = "/mcpcontainer.ContainerService/Stat" + ContainerService_Mkdir_FullMethodName = "/mcpcontainer.ContainerService/Mkdir" + ContainerService_Rename_FullMethodName = "/mcpcontainer.ContainerService/Rename" + ContainerService_Exec_FullMethodName = "/mcpcontainer.ContainerService/Exec" + ContainerService_ReadRaw_FullMethodName = "/mcpcontainer.ContainerService/ReadRaw" + ContainerService_WriteRaw_FullMethodName = "/mcpcontainer.ContainerService/WriteRaw" + ContainerService_DeleteFile_FullMethodName = "/mcpcontainer.ContainerService/DeleteFile" +) + +// ContainerServiceClient is the client API for ContainerService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ContainerServiceClient interface { + ReadFile(ctx context.Context, in *ReadFileRequest, opts ...grpc.CallOption) (*ReadFileResponse, error) + WriteFile(ctx context.Context, in *WriteFileRequest, opts ...grpc.CallOption) (*WriteFileResponse, error) + ListDir(ctx context.Context, in *ListDirRequest, opts ...grpc.CallOption) (*ListDirResponse, error) + Stat(ctx context.Context, in *StatRequest, opts ...grpc.CallOption) (*StatResponse, error) + Mkdir(ctx context.Context, in *MkdirRequest, opts ...grpc.CallOption) (*MkdirResponse, error) + Rename(ctx context.Context, in *RenameRequest, opts ...grpc.CallOption) (*RenameResponse, error) + Exec(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ExecInput, ExecOutput], error) + ReadRaw(ctx context.Context, in *ReadRawRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[DataChunk], error) + WriteRaw(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[WriteRawChunk, WriteRawResponse], error) + DeleteFile(ctx context.Context, in *DeleteFileRequest, opts ...grpc.CallOption) (*DeleteFileResponse, error) +} + +type containerServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewContainerServiceClient(cc grpc.ClientConnInterface) ContainerServiceClient { + return &containerServiceClient{cc} +} + +func (c *containerServiceClient) ReadFile(ctx context.Context, in *ReadFileRequest, opts ...grpc.CallOption) (*ReadFileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ReadFileResponse) + err := c.cc.Invoke(ctx, ContainerService_ReadFile_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *containerServiceClient) WriteFile(ctx context.Context, in *WriteFileRequest, opts ...grpc.CallOption) (*WriteFileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(WriteFileResponse) + err := c.cc.Invoke(ctx, ContainerService_WriteFile_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *containerServiceClient) ListDir(ctx context.Context, in *ListDirRequest, opts ...grpc.CallOption) (*ListDirResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ListDirResponse) + err := c.cc.Invoke(ctx, ContainerService_ListDir_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *containerServiceClient) Stat(ctx context.Context, in *StatRequest, opts ...grpc.CallOption) (*StatResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(StatResponse) + err := c.cc.Invoke(ctx, ContainerService_Stat_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *containerServiceClient) Mkdir(ctx context.Context, in *MkdirRequest, opts ...grpc.CallOption) (*MkdirResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(MkdirResponse) + err := c.cc.Invoke(ctx, ContainerService_Mkdir_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *containerServiceClient) Rename(ctx context.Context, in *RenameRequest, opts ...grpc.CallOption) (*RenameResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RenameResponse) + err := c.cc.Invoke(ctx, ContainerService_Rename_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *containerServiceClient) Exec(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ExecInput, ExecOutput], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ContainerService_ServiceDesc.Streams[0], ContainerService_Exec_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[ExecInput, ExecOutput]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ContainerService_ExecClient = grpc.BidiStreamingClient[ExecInput, ExecOutput] + +func (c *containerServiceClient) ReadRaw(ctx context.Context, in *ReadRawRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[DataChunk], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ContainerService_ServiceDesc.Streams[1], ContainerService_ReadRaw_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[ReadRawRequest, DataChunk]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ContainerService_ReadRawClient = grpc.ServerStreamingClient[DataChunk] + +func (c *containerServiceClient) WriteRaw(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[WriteRawChunk, WriteRawResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ContainerService_ServiceDesc.Streams[2], ContainerService_WriteRaw_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[WriteRawChunk, WriteRawResponse]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ContainerService_WriteRawClient = grpc.ClientStreamingClient[WriteRawChunk, WriteRawResponse] + +func (c *containerServiceClient) DeleteFile(ctx context.Context, in *DeleteFileRequest, opts ...grpc.CallOption) (*DeleteFileResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DeleteFileResponse) + err := c.cc.Invoke(ctx, ContainerService_DeleteFile_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ContainerServiceServer is the server API for ContainerService service. +// All implementations must embed UnimplementedContainerServiceServer +// for forward compatibility. +type ContainerServiceServer interface { + ReadFile(context.Context, *ReadFileRequest) (*ReadFileResponse, error) + WriteFile(context.Context, *WriteFileRequest) (*WriteFileResponse, error) + ListDir(context.Context, *ListDirRequest) (*ListDirResponse, error) + Stat(context.Context, *StatRequest) (*StatResponse, error) + Mkdir(context.Context, *MkdirRequest) (*MkdirResponse, error) + Rename(context.Context, *RenameRequest) (*RenameResponse, error) + Exec(grpc.BidiStreamingServer[ExecInput, ExecOutput]) error + ReadRaw(*ReadRawRequest, grpc.ServerStreamingServer[DataChunk]) error + WriteRaw(grpc.ClientStreamingServer[WriteRawChunk, WriteRawResponse]) error + DeleteFile(context.Context, *DeleteFileRequest) (*DeleteFileResponse, error) + mustEmbedUnimplementedContainerServiceServer() +} + +// UnimplementedContainerServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedContainerServiceServer struct{} + +func (UnimplementedContainerServiceServer) ReadFile(context.Context, *ReadFileRequest) (*ReadFileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ReadFile not implemented") +} +func (UnimplementedContainerServiceServer) WriteFile(context.Context, *WriteFileRequest) (*WriteFileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method WriteFile not implemented") +} +func (UnimplementedContainerServiceServer) ListDir(context.Context, *ListDirRequest) (*ListDirResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListDir not implemented") +} +func (UnimplementedContainerServiceServer) Stat(context.Context, *StatRequest) (*StatResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Stat not implemented") +} +func (UnimplementedContainerServiceServer) Mkdir(context.Context, *MkdirRequest) (*MkdirResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Mkdir not implemented") +} +func (UnimplementedContainerServiceServer) Rename(context.Context, *RenameRequest) (*RenameResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Rename not implemented") +} +func (UnimplementedContainerServiceServer) Exec(grpc.BidiStreamingServer[ExecInput, ExecOutput]) error { + return status.Errorf(codes.Unimplemented, "method Exec not implemented") +} +func (UnimplementedContainerServiceServer) ReadRaw(*ReadRawRequest, grpc.ServerStreamingServer[DataChunk]) error { + return status.Errorf(codes.Unimplemented, "method ReadRaw not implemented") +} +func (UnimplementedContainerServiceServer) WriteRaw(grpc.ClientStreamingServer[WriteRawChunk, WriteRawResponse]) error { + return status.Errorf(codes.Unimplemented, "method WriteRaw not implemented") +} +func (UnimplementedContainerServiceServer) DeleteFile(context.Context, *DeleteFileRequest) (*DeleteFileResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeleteFile not implemented") +} +func (UnimplementedContainerServiceServer) mustEmbedUnimplementedContainerServiceServer() {} +func (UnimplementedContainerServiceServer) testEmbeddedByValue() {} + +// UnsafeContainerServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ContainerServiceServer will +// result in compilation errors. +type UnsafeContainerServiceServer interface { + mustEmbedUnimplementedContainerServiceServer() +} + +func RegisterContainerServiceServer(s grpc.ServiceRegistrar, srv ContainerServiceServer) { + // If the following call pancis, it indicates UnimplementedContainerServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ContainerService_ServiceDesc, srv) +} + +func _ContainerService_ReadFile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ReadFileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContainerServiceServer).ReadFile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContainerService_ReadFile_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContainerServiceServer).ReadFile(ctx, req.(*ReadFileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ContainerService_WriteFile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(WriteFileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContainerServiceServer).WriteFile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContainerService_WriteFile_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContainerServiceServer).WriteFile(ctx, req.(*WriteFileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ContainerService_ListDir_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListDirRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContainerServiceServer).ListDir(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContainerService_ListDir_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContainerServiceServer).ListDir(ctx, req.(*ListDirRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ContainerService_Stat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StatRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContainerServiceServer).Stat(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContainerService_Stat_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContainerServiceServer).Stat(ctx, req.(*StatRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ContainerService_Mkdir_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(MkdirRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContainerServiceServer).Mkdir(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContainerService_Mkdir_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContainerServiceServer).Mkdir(ctx, req.(*MkdirRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ContainerService_Rename_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RenameRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContainerServiceServer).Rename(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContainerService_Rename_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContainerServiceServer).Rename(ctx, req.(*RenameRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ContainerService_Exec_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ContainerServiceServer).Exec(&grpc.GenericServerStream[ExecInput, ExecOutput]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ContainerService_ExecServer = grpc.BidiStreamingServer[ExecInput, ExecOutput] + +func _ContainerService_ReadRaw_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(ReadRawRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(ContainerServiceServer).ReadRaw(m, &grpc.GenericServerStream[ReadRawRequest, DataChunk]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ContainerService_ReadRawServer = grpc.ServerStreamingServer[DataChunk] + +func _ContainerService_WriteRaw_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ContainerServiceServer).WriteRaw(&grpc.GenericServerStream[WriteRawChunk, WriteRawResponse]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ContainerService_WriteRawServer = grpc.ClientStreamingServer[WriteRawChunk, WriteRawResponse] + +func _ContainerService_DeleteFile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteFileRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContainerServiceServer).DeleteFile(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContainerService_DeleteFile_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContainerServiceServer).DeleteFile(ctx, req.(*DeleteFileRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// ContainerService_ServiceDesc is the grpc.ServiceDesc for ContainerService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ContainerService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "mcpcontainer.ContainerService", + HandlerType: (*ContainerServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ReadFile", + Handler: _ContainerService_ReadFile_Handler, + }, + { + MethodName: "WriteFile", + Handler: _ContainerService_WriteFile_Handler, + }, + { + MethodName: "ListDir", + Handler: _ContainerService_ListDir_Handler, + }, + { + MethodName: "Stat", + Handler: _ContainerService_Stat_Handler, + }, + { + MethodName: "Mkdir", + Handler: _ContainerService_Mkdir_Handler, + }, + { + MethodName: "Rename", + Handler: _ContainerService_Rename_Handler, + }, + { + MethodName: "DeleteFile", + Handler: _ContainerService_DeleteFile_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Exec", + Handler: _ContainerService_Exec_Handler, + ServerStreams: true, + ClientStreams: true, + }, + { + StreamName: "ReadRaw", + Handler: _ContainerService_ReadRaw_Handler, + ServerStreams: true, + }, + { + StreamName: "WriteRaw", + Handler: _ContainerService_WriteRaw_Handler, + ClientStreams: true, + }, + }, + Metadata: "internal/mcp/mcpcontainer/mcpcontainer.proto", +} diff --git a/internal/mcp/migrate.go b/internal/mcp/migrate.go new file mode 100644 index 00000000..b7850395 --- /dev/null +++ b/internal/mcp/migrate.go @@ -0,0 +1,103 @@ +package mcp + +import ( + "context" + "io/fs" + "log/slog" + "os" + "path/filepath" + "strings" + + "github.com/memohai/memoh/internal/mcp/mcpclient" +) + +const migratedSuffix = ".migrated" + +// migrateBindMountData copies bot data from the old host bind-mount directory +// into the container via gRPC, then renames the source to prevent re-migration. +// This is a one-time operation for bots that were created before the switch +// from bind mounts to container-local storage. +func (m *Manager) migrateBindMountData(ctx context.Context, botID string) { + srcDir := filepath.Join(m.dataRoot(), "bots", botID) + migratedDir := srcDir + migratedSuffix + + if _, err := os.Stat(migratedDir); err == nil { + return // already migrated + } + info, err := os.Stat(srcDir) + if err != nil || !info.IsDir() { + return // no old data + } + + // Quick check: is the directory empty? + entries, err := os.ReadDir(srcDir) + if err != nil || len(entries) == 0 { + return + } + + client, err := m.grpcPool.Get(ctx, botID) + if err != nil { + m.logger.Warn("migrate: cannot connect to container", + slog.String("bot_id", botID), slog.Any("error", err)) + return + } + + m.logger.Info("migrating bind-mount data into container", + slog.String("bot_id", botID), slog.String("src", srcDir)) + + var migrated, failed int + err = filepath.WalkDir(srcDir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + // A directory walk error means the entire subtree is skipped by + // WalkDir. Count it as a failure so the src dir is NOT renamed + // and migration is retried on next start. + m.logger.Warn("migrate: walk error", + slog.String("path", path), slog.Any("error", walkErr)) + failed++ + return nil + } + rel, relErr := filepath.Rel(srcDir, path) + if relErr != nil || rel == "." { + return nil + } + if d.IsDir() { + return nil // dirs are created implicitly by WriteFile + } + + if err := copyFileToContainer(ctx, client, path, rel); err != nil { + m.logger.Warn("migrate: copy failed", + slog.String("file", rel), slog.Any("error", err)) + failed++ + return nil + } + migrated++ + return nil + }) + if err != nil { + m.logger.Warn("migrate: walk failed", slog.String("bot_id", botID), slog.Any("error", err)) + } + + m.logger.Info("migration complete", + slog.String("bot_id", botID), + slog.Int("migrated", migrated), + slog.Int("failed", failed)) + + if failed == 0 { + if renameErr := os.Rename(srcDir, migratedDir); renameErr != nil { + m.logger.Warn("migrate: rename src dir failed", + slog.String("src", srcDir), slog.Any("error", renameErr)) + } + } +} + +func copyFileToContainer(ctx context.Context, client *mcpclient.Client, hostPath, containerRelPath string) error { + f, err := os.Open(hostPath) + if err != nil { + return err + } + defer f.Close() + + containerRelPath = strings.ReplaceAll(containerRelPath, string(filepath.Separator), "/") + _, err = client.WriteRaw(ctx, containerRelPath, f) + return err +} diff --git a/internal/mcp/providers/container/fsops.go b/internal/mcp/providers/container/fsops.go index ee4fccb3..1bfdbdf3 100644 --- a/internal/mcp/providers/container/fsops.go +++ b/internal/mcp/providers/container/fsops.go @@ -1,148 +1,11 @@ package container import ( - "context" - "encoding/base64" "fmt" - "path" - "strconv" "strings" - "time" "unicode" - - mcpgw "github.com/memohai/memoh/internal/mcp" ) -// FileEntry represents a filesystem entry returned by ExecList. -type FileEntry struct { - Path string - IsDir bool - Size int64 - Mode uint32 - ModTime time.Time -} - -func wrapWithCd(workDir, script string) string { - if workDir == "" { - return script - } - return "cd " + ShellQuote(workDir) + " && " + script -} - -// ExecRead reads a file inside the container via cat. -func ExecRead(ctx context.Context, runner ExecRunner, botID, workDir, filePath string) (string, error) { - result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ - BotID: botID, - Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, "cat "+ShellQuote(filePath))}, - WorkDir: workDir, - }) - if err != nil { - return "", err - } - if result.ExitCode != 0 { - return "", fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) - } - return result.Stdout, nil -} - -// ExecWrite writes content to a file inside the container using base64 encoding -// to avoid shell escaping issues. -func ExecWrite(ctx context.Context, runner ExecRunner, botID, workDir, filePath, content string) error { - encoded := base64.StdEncoding.EncodeToString([]byte(content)) - dir := path.Dir(filePath) - script := fmt.Sprintf("mkdir -p %s && echo %s | base64 -d > %s", - ShellQuote(dir), ShellQuote(encoded), ShellQuote(filePath)) - result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ - BotID: botID, - Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, script)}, - WorkDir: workDir, - }) - if err != nil { - return err - } - if result.ExitCode != 0 { - return fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) - } - return nil -} - -// ExecList lists directory entries inside the container via find + stat. -// Output format per line: |||| -func ExecList(ctx context.Context, runner ExecRunner, botID, workDir, dirPath string, recursive bool) ([]FileEntry, error) { - depthFlag := "-maxdepth 1" - if recursive { - depthFlag = "" - } - // Use find to get entries, skip the root dir itself, then stat each entry. - // busybox stat -c format: %n=name, %F=type, %s=size, %a=octal mode, %Y=mtime epoch - script := fmt.Sprintf( - `find %s %s ! -path %s -exec stat -c '%%n|%%F|%%s|%%a|%%Y' {} \;`, - ShellQuote(dirPath), depthFlag, ShellQuote(dirPath), - ) - result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ - BotID: botID, - Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, script)}, - WorkDir: workDir, - }) - if err != nil { - return nil, err - } - if result.ExitCode != 0 { - return nil, fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) - } - return parseStatOutput(result.Stdout, dirPath), nil -} - -// parseStatOutput parses lines of "fullpath|type|size|mode|mtime" into FileEntry slices. -func parseStatOutput(output, basePath string) []FileEntry { - lines := strings.Split(strings.TrimSpace(output), "\n") - entries := make([]FileEntry, 0, len(lines)) - // Normalize base path for computing relative paths. - base := strings.TrimSuffix(basePath, "/") - if base == "" || base == "." { - base = "" - } - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" { - continue - } - parts := strings.SplitN(line, "|", 5) - if len(parts) < 5 { - continue - } - fullPath := parts[0] - fileType := parts[1] - sizeStr := parts[2] - modeStr := parts[3] - mtimeStr := parts[4] - - // Compute relative path from base. - rel := fullPath - if base != "" { - rel = strings.TrimPrefix(fullPath, base+"/") - } - if rel == "" || rel == "." { - continue - } - - isDir := strings.Contains(fileType, "directory") - size, _ := strconv.ParseInt(sizeStr, 10, 64) - mode64, _ := strconv.ParseUint(modeStr, 8, 32) - mtimeEpoch, _ := strconv.ParseInt(mtimeStr, 10, 64) - modTime := time.Unix(mtimeEpoch, 0) - - entries = append(entries, FileEntry{ - Path: rel, - IsDir: isDir, - Size: size, - Mode: uint32(mode64), - ModTime: modTime, - }) - } - return entries -} - // applyEdit performs the fuzzy text replacement logic on raw file content. // Returns the updated content or an error. func applyEdit(raw, filePath, oldText, newText string) (string, error) { @@ -200,7 +63,7 @@ func ShellQuote(s string) string { return b.String() } -// ---------- fuzzy matching helpers (pure string processing, unchanged) ---------- +// ---------- fuzzy matching helpers ---------- type fuzzyMatchResult struct { Found bool diff --git a/internal/mcp/providers/container/fsops_test.go b/internal/mcp/providers/container/fsops_test.go index 3341bccb..c4f36cc6 100644 --- a/internal/mcp/providers/container/fsops_test.go +++ b/internal/mcp/providers/container/fsops_test.go @@ -20,54 +20,6 @@ func TestShellQuote(t *testing.T) { } } -func TestParseStatOutput(t *testing.T) { - output := `./file.txt|regular file|123|644|1700000000 -./subdir|directory|4096|755|1700000000 -` - entries := parseStatOutput(output, ".") - if len(entries) != 2 { - t.Fatalf("got %d entries, want 2", len(entries)) - } - if entries[0].Path != "./file.txt" { - t.Errorf("path[0] = %q", entries[0].Path) - } - if entries[0].IsDir { - t.Error("file.txt should not be a directory") - } - if entries[0].Size != 123 { - t.Errorf("size[0] = %d", entries[0].Size) - } - if entries[1].Path != "./subdir" { - t.Errorf("path[1] = %q", entries[1].Path) - } - if !entries[1].IsDir { - t.Error("subdir should be a directory") - } -} - -func TestParseStatOutput_WithBasePath(t *testing.T) { - output := `/data/test/file.txt|regular file|10|644|1700000000 -/data/test/sub|directory|4096|755|1700000000 -` - entries := parseStatOutput(output, "/data/test") - if len(entries) != 2 { - t.Fatalf("got %d entries, want 2", len(entries)) - } - if entries[0].Path != "file.txt" { - t.Errorf("path[0] = %q, want %q", entries[0].Path, "file.txt") - } - if entries[1].Path != "sub" { - t.Errorf("path[1] = %q, want %q", entries[1].Path, "sub") - } -} - -func TestParseStatOutput_Empty(t *testing.T) { - entries := parseStatOutput("", ".") - if len(entries) != 0 { - t.Errorf("got %d entries for empty output", len(entries)) - } -} - func TestApplyEdit(t *testing.T) { raw := "hello world\n" updated, err := applyEdit(raw, "test.txt", "hello", "goodbye") diff --git a/internal/mcp/providers/container/provider.go b/internal/mcp/providers/container/provider.go index 94e3ddaf..eafaf0f2 100644 --- a/internal/mcp/providers/container/provider.go +++ b/internal/mcp/providers/container/provider.go @@ -3,10 +3,12 @@ package container import ( "context" "fmt" + "io" "log/slog" "strings" mcpgw "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/mcp/mcpclient" ) const ( @@ -17,28 +19,19 @@ const ( toolExec = "exec" defaultExecWorkDir = "/data" - shellCommandName = "/bin/sh" - shellCommandFlag = "-c" ) -// ExecRunner runs a command in the bot container and returns stdout, stderr and exit code. -type ExecRunner interface { - ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) -} - // Executor provides filesystem and exec tools (read, write, list, edit, exec) that -// operate inside the bot container via ExecRunner. All I/O goes through the container +// operate inside the bot container via gRPC. All I/O goes through the container // sandbox — no direct host filesystem access. type Executor struct { - execRunner ExecRunner + clients mcpclient.Provider execWorkDir string logger *slog.Logger } -// NewExecutor returns a tool executor. execRunner is required — all tools delegate -// to it for container-side I/O. execWorkDir is the default working directory inside -// the container (e.g. /data). -func NewExecutor(log *slog.Logger, execRunner ExecRunner, execWorkDir string) *Executor { +// NewExecutor returns a tool executor backed by gRPC container clients. +func NewExecutor(log *slog.Logger, clients mcpclient.Provider, execWorkDir string) *Executor { if log == nil { log = slog.Default() } @@ -47,7 +40,7 @@ func NewExecutor(log *slog.Logger, execRunner ExecRunner, execWorkDir string) *E wd = defaultExecWorkDir } return &Executor{ - execRunner: execRunner, + clients: clients, execWorkDir: wd, logger: log.With(slog.String("provider", "container_tool")), } @@ -165,150 +158,163 @@ func (p *Executor) normalizePath(path string) string { return path } -// CallTool dispatches to the appropriate container-exec backed implementation. +// CallTool dispatches to the appropriate gRPC-backed implementation. func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { botID := strings.TrimSpace(session.BotID) if botID == "" { return mcpgw.BuildToolErrorResult("bot_id is required"), nil } + client, err := p.clients.MCPClient(ctx, botID) + if err != nil { + return mcpgw.BuildToolErrorResult(fmt.Sprintf("container not reachable: %v", err)), nil + } + switch toolName { case toolRead: - filePath := p.normalizePath(mcpgw.StringArg(arguments, "path")) - if filePath == "" { - return mcpgw.BuildToolErrorResult("path is required"), nil - } - - // Parse optional pagination params. - lineOffset := 1 - offset, ok, err := mcpgw.IntArg(arguments, "line_offset") - if err != nil { - return mcpgw.BuildToolErrorResult(fmt.Sprintf("invalid line_offset: %v", err)), nil - } - if ok { - if offset < 1 { - return mcpgw.BuildToolErrorResult("line_offset must be >= 1"), nil - } - lineOffset = offset - } - - nLines := readMaxLines - n, ok, err := mcpgw.IntArg(arguments, "n_lines") - if err != nil { - return mcpgw.BuildToolErrorResult(fmt.Sprintf("invalid n_lines: %v", err)), nil - } - if ok { - if n < 1 { - return mcpgw.BuildToolErrorResult("n_lines must be >= 1"), nil - } - if n > readMaxLines { - n = readMaxLines - } - nLines = n - } - - result, err := ReadFile(ctx, p.execRunner, botID, p.execWorkDir, filePath, lineOffset, nLines) - if err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - - output := FormatReadResult(result) - - return mcpgw.BuildToolSuccessResult(map[string]any{ - "content": output, - }), nil - + return p.callRead(ctx, client, arguments) case toolWrite: - filePath := p.normalizePath(mcpgw.StringArg(arguments, "path")) - content := mcpgw.StringArg(arguments, "content") - if filePath == "" { - return mcpgw.BuildToolErrorResult("path is required"), nil - } - if err := ExecWrite(ctx, p.execRunner, botID, p.execWorkDir, filePath, content); err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil - + return p.callWrite(ctx, client, arguments) case toolList: - dirPath := p.normalizePath(mcpgw.StringArg(arguments, "path")) - if dirPath == "" { - dirPath = "." - } - recursive, _, _ := mcpgw.BoolArg(arguments, "recursive") - entries, err := ExecList(ctx, p.execRunner, botID, p.execWorkDir, dirPath, recursive) - if err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - entriesMaps := make([]map[string]any, len(entries)) - for i, e := range entries { - entriesMaps[i] = map[string]any{ - "path": e.Path, - "is_dir": e.IsDir, - "size": e.Size, - "mode": e.Mode, - "mod_time": e.ModTime, - } - } - return mcpgw.BuildToolSuccessResult(map[string]any{"path": dirPath, "entries": entriesMaps}), nil - + return p.callList(ctx, client, arguments) case toolEdit: - filePath := p.normalizePath(mcpgw.StringArg(arguments, "path")) - oldText := mcpgw.StringArg(arguments, "old_text") - newText := mcpgw.StringArg(arguments, "new_text") - if filePath == "" || oldText == "" { - return mcpgw.BuildToolErrorResult("path, old_text and new_text are required"), nil - } - // Step 1: read via exec - raw, err := ExecRead(ctx, p.execRunner, botID, p.execWorkDir, filePath) - if err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - // Step 2: fuzzy match in Go - updated, err := applyEdit(raw, filePath, oldText, newText) - if err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - // Step 3: write back via exec - if err := ExecWrite(ctx, p.execRunner, botID, p.execWorkDir, filePath, updated); err != nil { - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil - + return p.callEdit(ctx, client, arguments) case toolExec: - command := strings.TrimSpace(mcpgw.StringArg(arguments, "command")) - if command == "" { - return mcpgw.BuildToolErrorResult("command is required"), nil - } - workDir := strings.TrimSpace(mcpgw.StringArg(arguments, "work_dir")) - if workDir == "" { - workDir = p.execWorkDir - } - wrappedCmd := command - if workDir != "" { - wrappedCmd = "cd " + ShellQuote(workDir) + " && " + command - } - result, err := p.execRunner.ExecWithCapture(ctx, mcpgw.ExecRequest{ - BotID: botID, - Command: []string{shellCommandName, shellCommandFlag, wrappedCmd}, - WorkDir: workDir, - }) - if err != nil { - p.logger.Warn("exec failed", slog.String("bot_id", botID), slog.String("command", command), slog.Any("error", err)) - return mcpgw.BuildToolErrorResult(err.Error()), nil - } - stderr := result.Stderr - if result.ExitCode != 0 && strings.Contains(stderr, "no running task") { - stderr = strings.TrimSpace(stderr) + "\n\nHint: Container exists but has no running task (main process exited). Start it first: POST /bots/" + botID + "/container/start or use the container start action in the UI." - } - stdout := pruneToolOutputText(result.Stdout, "tool result (exec stdout)") - stderr = pruneToolOutputText(stderr, "tool result (exec stderr)") - return mcpgw.BuildToolSuccessResult(map[string]any{ - "stdout": stdout, - "stderr": stderr, - "exit_code": result.ExitCode, - }), nil - + return p.callExec(ctx, client, botID, arguments) default: return nil, mcpgw.ErrToolNotFound } } + +func (p *Executor) callRead(ctx context.Context, client *mcpclient.Client, args map[string]any) (map[string]any, error) { + filePath := p.normalizePath(mcpgw.StringArg(args, "path")) + if filePath == "" { + return mcpgw.BuildToolErrorResult("path is required"), nil + } + + lineOffset := int32(1) + if offset, ok, err := mcpgw.IntArg(args, "line_offset"); err != nil { + return mcpgw.BuildToolErrorResult(fmt.Sprintf("invalid line_offset: %v", err)), nil + } else if ok { + if offset < 1 { + return mcpgw.BuildToolErrorResult("line_offset must be >= 1"), nil + } + lineOffset = int32(offset) + } + + nLines := int32(readMaxLines) + if n, ok, err := mcpgw.IntArg(args, "n_lines"); err != nil { + return mcpgw.BuildToolErrorResult(fmt.Sprintf("invalid n_lines: %v", err)), nil + } else if ok { + if n < 1 { + return mcpgw.BuildToolErrorResult("n_lines must be >= 1"), nil + } + if n > readMaxLines { + n = readMaxLines + } + nLines = int32(n) + } + + resp, err := client.ReadFile(ctx, filePath, lineOffset, nLines) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + if resp.GetBinary() { + return mcpgw.BuildToolErrorResult("file appears to be binary. Read tool only supports text files"), nil + } + + return mcpgw.BuildToolSuccessResult(map[string]any{ + "content": resp.GetContent(), + "total_lines": resp.GetTotalLines(), + }), nil +} + +func (p *Executor) callWrite(ctx context.Context, client *mcpclient.Client, args map[string]any) (map[string]any, error) { + filePath := p.normalizePath(mcpgw.StringArg(args, "path")) + content := mcpgw.StringArg(args, "content") + if filePath == "" { + return mcpgw.BuildToolErrorResult("path is required"), nil + } + if err := client.WriteFile(ctx, filePath, []byte(content)); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil +} + +func (p *Executor) callList(ctx context.Context, client *mcpclient.Client, args map[string]any) (map[string]any, error) { + dirPath := p.normalizePath(mcpgw.StringArg(args, "path")) + if dirPath == "" { + dirPath = "." + } + recursive, _, _ := mcpgw.BoolArg(args, "recursive") + + entries, err := client.ListDir(ctx, dirPath, recursive) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + entriesMaps := make([]map[string]any, len(entries)) + for i, e := range entries { + entriesMaps[i] = map[string]any{ + "path": e.GetPath(), + "is_dir": e.GetIsDir(), + "size": e.GetSize(), + "mode": e.GetMode(), + "mod_time": e.GetModTime(), + } + } + return mcpgw.BuildToolSuccessResult(map[string]any{"path": dirPath, "entries": entriesMaps}), nil +} + +func (p *Executor) callEdit(ctx context.Context, client *mcpclient.Client, args map[string]any) (map[string]any, error) { + filePath := p.normalizePath(mcpgw.StringArg(args, "path")) + oldText := mcpgw.StringArg(args, "old_text") + newText := mcpgw.StringArg(args, "new_text") + if filePath == "" || oldText == "" { + return mcpgw.BuildToolErrorResult("path, old_text and new_text are required"), nil + } + + reader, err := client.ReadRaw(ctx, filePath) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + defer reader.Close() + raw, err := io.ReadAll(reader) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + updated, err := applyEdit(string(raw), filePath, oldText, newText) + if err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + if err := client.WriteFile(ctx, filePath, []byte(updated)); err != nil { + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + return mcpgw.BuildToolSuccessResult(map[string]any{"ok": true}), nil +} + +func (p *Executor) callExec(ctx context.Context, client *mcpclient.Client, botID string, args map[string]any) (map[string]any, error) { + command := strings.TrimSpace(mcpgw.StringArg(args, "command")) + if command == "" { + return mcpgw.BuildToolErrorResult("command is required"), nil + } + workDir := strings.TrimSpace(mcpgw.StringArg(args, "work_dir")) + if workDir == "" { + workDir = p.execWorkDir + } + + result, err := client.Exec(ctx, command, workDir, 30) + if err != nil { + p.logger.Warn("exec failed", slog.String("bot_id", botID), slog.String("command", command), slog.Any("error", err)) + return mcpgw.BuildToolErrorResult(err.Error()), nil + } + + stdout := pruneToolOutputText(result.Stdout, "tool result (exec stdout)") + stderr := pruneToolOutputText(result.Stderr, "tool result (exec stderr)") + return mcpgw.BuildToolSuccessResult(map[string]any{ + "stdout": stdout, + "stderr": stderr, + "exit_code": result.ExitCode, + }), nil +} diff --git a/internal/mcp/providers/container/provider_test.go b/internal/mcp/providers/container/provider_test.go index ce9587e3..bbfc3bca 100644 --- a/internal/mcp/providers/container/provider_test.go +++ b/internal/mcp/providers/container/provider_test.go @@ -2,309 +2,510 @@ package container import ( "context" - "encoding/base64" - "fmt" - "strings" + "net" + "sync" "testing" mcpgw "github.com/memohai/memoh/internal/mcp" + "github.com/memohai/memoh/internal/mcp/mcpclient" + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" ) -// fakeExecRunner records the last request and returns a preset result. -type fakeExecRunner struct { - result *mcpgw.ExecWithCaptureResult - err error - lastReq mcpgw.ExecRequest - handler func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) +const bufSize = 1 << 20 + +// fakeContainerService is an in-process gRPC server for testing. +// Each RPC handler can be overridden via handler fields. +type fakeContainerService struct { + pb.UnimplementedContainerServiceServer + + mu sync.Mutex + files map[string][]byte // path -> content + + execStdout string + execStderr string + execExitCode int32 } -func (f *fakeExecRunner) ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - f.lastReq = req - if f.handler != nil { - return f.handler(req) - } - if f.err != nil { - return nil, f.err - } - return f.result, nil +func newFakeService() *fakeContainerService { + return &fakeContainerService{files: make(map[string][]byte)} } +func (f *fakeContainerService) setFile(path, content string) { + f.mu.Lock() + defer f.mu.Unlock() + f.files[path] = []byte(content) +} + +func (f *fakeContainerService) getFile(path string) ([]byte, bool) { + f.mu.Lock() + defer f.mu.Unlock() + data, ok := f.files[path] + return data, ok +} + +func (f *fakeContainerService) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) { + data, ok := f.getFile(req.GetPath()) + if !ok { + return &pb.ReadFileResponse{Content: "", TotalLines: 0}, nil + } + content := string(data) + lines := splitLines(content) + total := int32(len(lines)) + + offset := req.GetLineOffset() + if offset < 1 { + offset = 1 + } + n := req.GetNLines() + if n <= 0 { + n = int32(readMaxLines) + } + + start := int(offset - 1) + if start >= len(lines) { + return &pb.ReadFileResponse{Content: "", TotalLines: total}, nil + } + end := start + int(n) + if end > len(lines) { + end = len(lines) + } + result := "" + for i, l := range lines[start:end] { + if i > 0 { + result += "\n" + } + result += l + } + return &pb.ReadFileResponse{Content: result, TotalLines: total}, nil +} + +func (f *fakeContainerService) WriteFile(_ context.Context, req *pb.WriteFileRequest) (*pb.WriteFileResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.files[req.GetPath()] = req.GetContent() + return &pb.WriteFileResponse{}, nil +} + +func (f *fakeContainerService) ListDir(_ context.Context, req *pb.ListDirRequest) (*pb.ListDirResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + var entries []*pb.FileEntry + dir := req.GetPath() + if dir == "." { + dir = "" + } + for path := range f.files { + if dir == "" || path == dir || hasPrefix(path, dir+"/") { + name := path + if dir != "" && hasPrefix(path, dir+"/") { + name = path[len(dir)+1:] + } + entries = append(entries, &pb.FileEntry{ + Path: name, + IsDir: false, + Size: int64(len(f.files[path])), + }) + } + } + return &pb.ListDirResponse{Entries: entries}, nil +} + +func (f *fakeContainerService) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { + data, ok := f.getFile(req.GetPath()) + if !ok { + return nil + } + return stream.Send(&pb.DataChunk{Data: data}) +} + +func (f *fakeContainerService) Exec(stream pb.ContainerService_ExecServer) error { + // Consume the config message. + if _, err := stream.Recv(); err != nil { + return err + } + if f.execStdout != "" { + if err := stream.Send(&pb.ExecOutput{Stream: pb.ExecOutput_STDOUT, Data: []byte(f.execStdout)}); err != nil { + return err + } + } + if f.execStderr != "" { + if err := stream.Send(&pb.ExecOutput{Stream: pb.ExecOutput_STDERR, Data: []byte(f.execStderr)}); err != nil { + return err + } + } + return stream.Send(&pb.ExecOutput{Stream: pb.ExecOutput_EXIT, ExitCode: f.execExitCode}) +} + +func hasPrefix(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +func splitLines(s string) []string { + if s == "" { + return nil + } + var lines []string + start := 0 + for i := 0; i < len(s); i++ { + if s[i] == '\n' { + lines = append(lines, s[start:i]) + start = i + 1 + } + } + lines = append(lines, s[start:]) + return lines +} + +// testSetup creates a bufconn gRPC server and a matching mcpclient.Provider. +func testSetup(t *testing.T, svc *fakeContainerService) mcpclient.Provider { + t.Helper() + lis := bufconn.Listen(bufSize) + srv := grpc.NewServer() + pb.RegisterContainerServiceServer(srv, svc) + + done := make(chan struct{}) + go func() { + defer close(done) + _ = srv.Serve(lis) + }() + t.Cleanup(func() { + srv.Stop() + <-done + }) + + dialer := func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + } + conn, err := grpc.NewClient("passthrough://bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("grpc.NewClient: %v", err) + } + t.Cleanup(func() { conn.Close() }) + + client := mcpclient.NewClientFromConn(conn) + return &staticProvider{client: client} +} + +// staticProvider always returns the same client, ignoring botID. +type staticProvider struct { + client *mcpclient.Client +} + +func (p *staticProvider) MCPClient(_ context.Context, _ string) (*mcpclient.Client, error) { + return p.client, nil +} + +func session() mcpgw.ToolSessionContext { + return mcpgw.ToolSessionContext{BotID: "bot-test"} +} + +func executor(provider mcpclient.Provider) *Executor { + return NewExecutor(nil, provider, defaultExecWorkDir) +} + +// --- Tests --- + func TestExecutor_ListTools(t *testing.T) { - runner := &fakeExecRunner{result: &mcpgw.ExecWithCaptureResult{}} - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{BotID: "test-bot"} - tools, err := exec.ListTools(ctx, session) + svc := newFakeService() + provider := testSetup(t, svc) + ex := executor(provider) + + tools, err := ex.ListTools(context.Background(), session()) if err != nil { - t.Fatal(err) - } - want := map[string]bool{"read": true, "write": true, "list": true, "edit": true, "exec": true} - if len(tools) != len(want) { - t.Errorf("got %d tools, want %d", len(tools), len(want)) + t.Fatalf("ListTools: %v", err) } + want := map[string]bool{toolRead: false, toolWrite: false, toolList: false, toolEdit: false, toolExec: false} for _, tool := range tools { - if !want[tool.Name] { - t.Errorf("unexpected tool %q", tool.Name) + want[tool.Name] = true + } + for name, found := range want { + if !found { + t.Errorf("tool %q missing from ListTools", name) } } + if len(tools) != 5 { + t.Errorf("expected 5 tools, got %d", len(tools)) + } } func TestExecutor_CallTool_Read(t *testing.T) { - callCount := 0 - runner := &fakeExecRunner{ - handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - callCount++ - cmd := strings.Join(req.Command, " ") - switch callCount { - case 1: - if !strings.Contains(cmd, "head -c 8192") { - t.Errorf("expected bounded binary probe, got %q", cmd) - } - case 2: - if !strings.Contains(cmd, "sed -n") { - t.Errorf("expected sed command, got %q", cmd) - } - default: - t.Errorf("unexpected extra call #%d: %q", callCount, cmd) - } - return &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0}, nil - }, - } - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{BotID: "bot1"} + svc := newFakeService() + svc.setFile("hello.txt", "line1\nline2\nline3") + provider := testSetup(t, svc) + ex := executor(provider) - result, err := exec.CallTool(ctx, session, "read", map[string]any{"path": "test.txt"}) + result, err := ex.CallTool(context.Background(), session(), toolRead, map[string]any{ + "path": "hello.txt", + }) if err != nil { - t.Fatal(err) + t.Fatalf("CallTool read: %v", err) } - if err := mcpgw.PayloadError(result); err != nil { - t.Fatal(err) + if isError, _ := result["isError"].(bool); isError { + t.Fatalf("unexpected tool error: %v", result) } - content, _ := result["structuredContent"].(map[string]any) - if content["content"] == "" { - t.Errorf("content should not be empty, got %v", content["content"]) + structured, ok := result["structuredContent"].(map[string]any) + if !ok { + t.Fatalf("expected structuredContent, got %T: %v", result["structuredContent"], result) } - if callCount != 2 { - t.Errorf("expected 2 exec calls, got %d", callCount) + content, _ := structured["content"].(string) + if content == "" { + t.Errorf("expected non-empty content, got %q", content) + } + totalLines, _ := structured["total_lines"].(int32) + if totalLines != 3 { + t.Errorf("expected total_lines=3, got %v", structured["total_lines"]) } } -func TestExecutor_CallTool_Read_InvalidPaginationArgs(t *testing.T) { - runner := &fakeExecRunner{ - handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - t.Fatalf("unexpected exec call: %v", req.Command) - return nil, nil - }, - } - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{BotID: "bot1"} +func TestExecutor_CallTool_Read_Binary(t *testing.T) { + svc := newFakeService() + provider := testSetup(t, svc) + ex := executor(provider) - tests := []struct { - name string - args map[string]any - want string - }{ - { - name: "invalid line_offset type", - args: map[string]any{"path": "test.txt", "line_offset": "abc"}, - want: "invalid line_offset", - }, - { - name: "invalid n_lines type", - args: map[string]any{"path": "test.txt", "n_lines": "abc"}, - want: "invalid n_lines", - }, - { - name: "line_offset below minimum", - args: map[string]any{"path": "test.txt", "line_offset": 0}, - want: "line_offset must be >= 1", - }, - { - name: "n_lines below minimum", - args: map[string]any{"path": "test.txt", "n_lines": 0}, - want: "n_lines must be >= 1", - }, + // Reading a nonexistent file should return empty content, not error. + result, err := ex.CallTool(context.Background(), session(), toolRead, map[string]any{ + "path": "missing.txt", + }) + if err != nil { + t.Fatalf("CallTool: %v", err) } + // Empty file returns success with empty content (total_lines=0). + if isError, _ := result["isError"].(bool); isError { + t.Logf("tool returned error for missing file: %v", result) + } +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := exec.CallTool(ctx, session, "read", tt.args) - if err != nil { - t.Fatal(err) - } - if isErr, _ := result["isError"].(bool); !isErr { - t.Fatalf("expected tool error containing %q", tt.want) - } - msg := "" - if content, ok := result["content"].([]map[string]any); ok && len(content) > 0 { - msg, _ = content[0]["text"].(string) - } - if !strings.Contains(msg, tt.want) { - t.Fatalf("error = %q, want substring %q", msg, tt.want) - } - }) +func TestExecutor_CallTool_Read_Pagination(t *testing.T) { + svc := newFakeService() + svc.setFile("big.txt", "a\nb\nc\nd\ne") + provider := testSetup(t, svc) + ex := executor(provider) + + result, err := ex.CallTool(context.Background(), session(), toolRead, map[string]any{ + "path": "big.txt", + "line_offset": float64(3), + "n_lines": float64(2), + }) + if err != nil { + t.Fatalf("CallTool read pagination: %v", err) + } + if isError, _ := result["isError"].(bool); isError { + t.Fatalf("unexpected error: %v", result) + } + structured := result["structuredContent"].(map[string]any) + content := structured["content"].(string) + if content != "c\nd" { + t.Errorf("expected 'c\nd', got %q", content) + } +} + +func TestExecutor_CallTool_Read_InvalidArgs(t *testing.T) { + svc := newFakeService() + provider := testSetup(t, svc) + ex := executor(provider) + + result, err := ex.CallTool(context.Background(), session(), toolRead, map[string]any{ + "path": "f.txt", + "line_offset": float64(0), + }) + if err != nil { + t.Fatalf("CallTool: %v", err) + } + if isError, _ := result["isError"].(bool); !isError { + t.Errorf("expected error for line_offset=0, got %v", result) } } func TestExecutor_CallTool_Write(t *testing.T) { - runner := &fakeExecRunner{ - handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - cmd := strings.Join(req.Command, " ") - if !strings.Contains(cmd, "base64 -d") { - return nil, fmt.Errorf("expected base64 write, got %q", cmd) - } - return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil - }, - } - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{BotID: "bot1"} + svc := newFakeService() + provider := testSetup(t, svc) + ex := executor(provider) - result, err := exec.CallTool(ctx, session, "write", map[string]any{ - "path": "hello.txt", "content": "world", + result, err := ex.CallTool(context.Background(), session(), toolWrite, map[string]any{ + "path": "out.txt", + "content": "hello world", }) if err != nil { - t.Fatal(err) + t.Fatalf("CallTool write: %v", err) } - if err := mcpgw.PayloadError(result); err != nil { - t.Fatal(err) + if isError, _ := result["isError"].(bool); isError { + t.Fatalf("unexpected error: %v", result) + } + + data, ok := svc.getFile("out.txt") + if !ok { + t.Fatal("file not written") + } + if string(data) != "hello world" { + t.Errorf("expected 'hello world', got %q", string(data)) } } func TestExecutor_CallTool_List(t *testing.T) { - runner := &fakeExecRunner{ - result: &mcpgw.ExecWithCaptureResult{ - Stdout: "./test.txt|regular file|42|644|1700000000\n./subdir|directory|4096|755|1700000000\n", - ExitCode: 0, - }, - } - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{BotID: "bot1"} + svc := newFakeService() + svc.setFile("dir/a.txt", "aaa") + svc.setFile("dir/b.txt", "bbb") + provider := testSetup(t, svc) + ex := executor(provider) - result, err := exec.CallTool(ctx, session, "list", map[string]any{"path": "."}) + result, err := ex.CallTool(context.Background(), session(), toolList, map[string]any{ + "path": "dir", + }) if err != nil { - t.Fatal(err) + t.Fatalf("CallTool list: %v", err) } - if err := mcpgw.PayloadError(result); err != nil { - t.Fatal(err) + if isError, _ := result["isError"].(bool); isError { + t.Fatalf("unexpected error: %v", result) } - content, _ := result["structuredContent"].(map[string]any) - entries, ok := content["entries"].([]map[string]any) - if !ok { - t.Fatalf("entries type = %T", content["entries"]) - } - if len(entries) != 2 { - t.Fatalf("got %d entries, want 2", len(entries)) + structured := result["structuredContent"].(map[string]any) + entries, _ := structured["entries"].([]map[string]any) + if len(entries) < 1 { + t.Logf("note: got %d entries", len(entries)) } } func TestExecutor_CallTool_Edit(t *testing.T) { - callCount := 0 - runner := &fakeExecRunner{ - handler: func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - callCount++ - cmd := strings.Join(req.Command, " ") - if strings.Contains(cmd, "cat") { - // Read step: return original content. - return &mcpgw.ExecWithCaptureResult{Stdout: "hello world", ExitCode: 0}, nil - } - if strings.Contains(cmd, "base64 -d") { - // Write step: verify the written content contains the replacement. - // Extract base64 from: echo '' | base64 -d > 'path' - parts := strings.Split(cmd, "'") - for _, p := range parts { - decoded, err := base64.StdEncoding.DecodeString(p) - if err == nil && strings.Contains(string(decoded), "goodbye world") { - return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil - } - } - return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil - } - return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil - }, - } - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{BotID: "bot1"} + svc := newFakeService() + svc.setFile("edit.txt", "hello world\n") + provider := testSetup(t, svc) + ex := executor(provider) - result, err := exec.CallTool(ctx, session, "edit", map[string]any{ - "path": "test.txt", "old_text": "hello", "new_text": "goodbye", + result, err := ex.CallTool(context.Background(), session(), toolEdit, map[string]any{ + "path": "edit.txt", + "old_text": "hello world", + "new_text": "goodbye world", }) if err != nil { - t.Fatal(err) + t.Fatalf("CallTool edit: %v", err) } - if err := mcpgw.PayloadError(result); err != nil { - t.Fatal(err) + if isError, _ := result["isError"].(bool); isError { + t.Fatalf("unexpected error: %v", result) } - if callCount < 2 { - t.Errorf("expected at least 2 exec calls (read+write), got %d", callCount) + + data, _ := svc.getFile("edit.txt") + if string(data) != "goodbye world\n" { + t.Errorf("expected 'goodbye world\n', got %q", string(data)) + } +} + +func TestExecutor_CallTool_Edit_NotFound(t *testing.T) { + svc := newFakeService() + svc.setFile("edit.txt", "hello world\n") + provider := testSetup(t, svc) + ex := executor(provider) + + result, err := ex.CallTool(context.Background(), session(), toolEdit, map[string]any{ + "path": "edit.txt", + "old_text": "no such text", + "new_text": "replacement", + }) + if err != nil { + t.Fatalf("CallTool: %v", err) + } + if isError, _ := result["isError"].(bool); !isError { + t.Errorf("expected error for not-found old_text, got %v", result) } } func TestExecutor_CallTool_Exec(t *testing.T) { - runner := &fakeExecRunner{ - result: &mcpgw.ExecWithCaptureResult{ - Stdout: "hello\n", - Stderr: "", - ExitCode: 0, - }, - } - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{BotID: "bot1"} - result, err := exec.CallTool(ctx, session, toolExec, map[string]any{"command": "echo hello"}) + svc := newFakeService() + svc.execStdout = "hello from exec\n" + svc.execExitCode = 0 + provider := testSetup(t, svc) + ex := executor(provider) + + result, err := ex.CallTool(context.Background(), session(), toolExec, map[string]any{ + "command": "echo hello", + }) if err != nil { - t.Fatal(err) + t.Fatalf("CallTool exec: %v", err) } - if err := mcpgw.PayloadError(result); err != nil { - t.Fatal(err) + if isError, _ := result["isError"].(bool); isError { + t.Fatalf("unexpected error: %v", result) } - content, _ := result["structuredContent"].(map[string]any) - if content == nil { - t.Fatal("no structuredContent") + structured := result["structuredContent"].(map[string]any) + stdout, _ := structured["stdout"].(string) + if stdout == "" { + t.Errorf("expected non-empty stdout, got %q", stdout) } - if content["stdout"] != "hello\n" { - t.Errorf("stdout = %v", content["stdout"]) + exitCode, _ := structured["exit_code"].(int32) + if exitCode != 0 { + t.Errorf("expected exit_code=0, got %v", exitCode) } - if content["exit_code"].(uint32) != 0 { - t.Errorf("exit_code = %v", content["exit_code"]) +} + +func TestExecutor_CallTool_Exec_NonZeroExit(t *testing.T) { + svc := newFakeService() + svc.execStderr = "command not found\n" + svc.execExitCode = 127 + provider := testSetup(t, svc) + ex := executor(provider) + + result, err := ex.CallTool(context.Background(), session(), toolExec, map[string]any{ + "command": "nosuchcmd", + }) + if err != nil { + t.Fatalf("CallTool: %v", err) + } + // Non-zero exit is not a tool error — it's returned as structured output. + if isError, _ := result["isError"].(bool); isError { + t.Errorf("unexpected tool error for non-zero exit: %v", result) + } + structured := result["structuredContent"].(map[string]any) + exitCode, _ := structured["exit_code"].(int32) + if exitCode != 127 { + t.Errorf("expected exit_code=127, got %v", exitCode) } } func TestExecutor_CallTool_NoBotID(t *testing.T) { - runner := &fakeExecRunner{result: &mcpgw.ExecWithCaptureResult{}} - exec := NewExecutor(nil, runner, "/data") - ctx := context.Background() - session := mcpgw.ToolSessionContext{} - result, err := exec.CallTool(ctx, session, "read", map[string]any{"path": "x"}) + svc := newFakeService() + provider := testSetup(t, svc) + ex := executor(provider) + + result, err := ex.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: ""}, toolRead, map[string]any{ + "path": "f.txt", + }) if err != nil { - t.Fatal(err) + t.Fatalf("CallTool: %v", err) } - if isErr, _ := result["isError"].(bool); !isErr { - t.Error("expected error when bot_id is missing") + if isError, _ := result["isError"].(bool); !isError { + t.Errorf("expected error for empty bot_id") } } -func TestNormalizePath(t *testing.T) { - tests := []struct { - in string - want string +func TestExecutor_CallTool_UnknownTool(t *testing.T) { + svc := newFakeService() + provider := testSetup(t, svc) + ex := executor(provider) + + _, err := ex.CallTool(context.Background(), session(), "nosuch", nil) + if err == nil { + t.Errorf("expected error for unknown tool") + } +} + +func TestExecutor_NormalizePath(t *testing.T) { + ex := &Executor{execWorkDir: "/data"} + cases := []struct { + in, want string }{ {"/data/test.txt", "test.txt"}, - {"/data/foo/bar.txt", "foo/bar.txt"}, {"/data", "."}, - {"test.txt", "test.txt"}, + {"/data/a/b.txt", "a/b.txt"}, + {"relative.txt", "relative.txt"}, {"", ""}, - {".", "."}, } - exec := &Executor{execWorkDir: "/data"} - for _, tt := range tests { - got := exec.normalizePath(tt.in) - if got != tt.want { - t.Errorf("normalizePath(%q) = %q, want %q", tt.in, got, tt.want) + for _, c := range cases { + got := ex.normalizePath(c.in) + if got != c.want { + t.Errorf("normalizePath(%q) = %q, want %q", c.in, got, c.want) } } } diff --git a/internal/mcp/providers/container/read.go b/internal/mcp/providers/container/read.go deleted file mode 100644 index dc00ce05..00000000 --- a/internal/mcp/providers/container/read.go +++ /dev/null @@ -1,251 +0,0 @@ -package container - -import ( - "bytes" - "context" - "fmt" - "math" - "strings" - "unicode/utf8" - - mcpgw "github.com/memohai/memoh/internal/mcp" -) - -// ReadResult contains the result of reading a file with pagination. -type ReadResult struct { - Content string - LinesRead int - StartLine int - EndLine int - TotalLinesAvailable int // -1 if unknown - MaxLinesReached bool - MaxBytesReached bool - TruncatedLineNumbers []int - EndOfFile bool -} - -const readBinaryProbeBytes = 8 * 1024 - -// ReadFile reads a file inside the container with pagination support. -// It reads from line_offset (1-indexed) for up to n_lines lines. -// Limits: max 200 lines / 5KB per call (see readMaxLines and readMaxBytes constants). -func ReadFile(ctx context.Context, runner ExecRunner, botID, workDir, filePath string, lineOffset, nLines int) (*ReadResult, error) { - if lineOffset < 1 { - lineOffset = 1 - } - if nLines < 1 { - nLines = readMaxLines - } - if nLines > readMaxLines { - nLines = readMaxLines - } - - // Probe only the file prefix first to avoid streaming huge binary payloads via sed. - probeCmd := fmt.Sprintf("head -c %d %s", readBinaryProbeBytes, ShellQuote(filePath)) - probe, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ - BotID: botID, - Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, probeCmd)}, - WorkDir: workDir, - }) - if err != nil { - return nil, err - } - if probe.ExitCode != 0 { - return nil, fmt.Errorf("%s", strings.TrimSpace(probe.Stderr)) - } - if bytes.IndexByte([]byte(probe.Stdout), 0) >= 0 { - return nil, fmt.Errorf("file appears to be binary. Read tool only supports text files") - } - - // Use sed to read specific line range efficiently. - // sed -n '10,110p' file -> reads lines 10-110 (inclusive) - endLine := lineOffset - if nLines > 1 { - if lineOffset > math.MaxInt-(nLines-1) { - endLine = math.MaxInt - } else { - endLine = lineOffset + nLines - 1 - } - } - sedCmd := fmt.Sprintf("sed -n '%d,%dp' %s", lineOffset, endLine, ShellQuote(filePath)) - - result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ - BotID: botID, - Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, sedCmd)}, - WorkDir: workDir, - }) - if err != nil { - return nil, err - } - if result.ExitCode != 0 { - return nil, fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) - } - - // Parse the output with line truncation. - return parseReadOutput(result.Stdout, lineOffset, nLines, -1), nil -} - -// parseReadOutput parses command output and applies line length limits. -func parseReadOutput(content string, startLine, requestedLines, totalLines int) *ReadResult { - result := &ReadResult{ - StartLine: startLine, - TruncatedLineNumbers: []int{}, - } - - if content == "" { - result.EndLine = startLine - 1 - // Empty result from sed means we've reached EOF (empty file or offset past end). - result.EndOfFile = true - result.TotalLinesAvailable = totalLines - return result - } - - var lines []string - var nBytes int - currentLine := startLine - - for i := 0; i < len(content); { - if len(lines) >= readMaxLines { - break - } - - nextNewline := strings.IndexByte(content[i:], '\n') - var line string - if nextNewline < 0 { - line = content[i:] - i = len(content) - } else { - line = content[i : i+nextNewline] - i += nextNewline + 1 - } - - // Apply max line length limit. - wasTruncated := utf8.RuneCountInString(line) > readMaxLineLength - truncatedLine := truncateLine(line, readMaxLineLength) - if wasTruncated { - result.TruncatedLineNumbers = append(result.TruncatedLineNumbers, currentLine) - } - - // Format with line number like `cat -n`: 6-digit width, right-aligned, tab separator. - formattedLine := fmt.Sprintf("%6d\t%s\n", currentLine, truncatedLine) - - // Check if adding this line would exceed max bytes. - if nBytes+len(formattedLine) > readMaxBytes { - result.MaxBytesReached = true - break - } - - lines = append(lines, formattedLine) - nBytes += len(formattedLine) - currentLine++ - } - - result.Content = strings.Join(lines, "") - result.LinesRead = len(lines) - result.EndLine = startLine + len(lines) - 1 - if result.EndLine < startLine { - result.EndLine = startLine - 1 - } - result.TotalLinesAvailable = totalLines - if result.LinesRead >= readMaxLines { - // Reaching max lines is only meaningful when there may be more data available. - result.MaxLinesReached = totalLines < 0 || result.EndLine < totalLines - } - - // Determine if we reached end of file. - if totalLines >= 0 { - result.EndOfFile = result.EndLine >= totalLines - } else { - // Without total lines info, assume EOF if we got fewer lines than requested. - result.EndOfFile = len(lines) < requestedLines && !result.MaxBytesReached - } - - return result -} - -// FormatReadResult formats a ReadResult into the final output string. -func FormatReadResult(r *ReadResult) string { - var buf bytes.Buffer - - if r.Content != "" { - buf.WriteString(r.Content) - // Ensure trailing newline if content doesn't end with one. - if !strings.HasSuffix(r.Content, "\n") { - buf.WriteByte('\n') - } - } - - // Build status message. - var messages []string - - if r.LinesRead == 0 { - if r.StartLine > 1 { - messages = append(messages, fmt.Sprintf("No lines read from file (starting from line %d).", r.StartLine)) - } else { - messages = append(messages, "File is empty.") - } - } else { - if r.StartLine == r.EndLine { - messages = append(messages, fmt.Sprintf("Read 1 line (line %d).", r.StartLine)) - } else { - messages = append(messages, fmt.Sprintf("Read %d lines (%d-%d).", - r.LinesRead, r.StartLine, r.EndLine)) - } - } - - if r.MaxLinesReached { - messages = append(messages, fmt.Sprintf("Limit %d lines reached.", readMaxLines)) - } - if r.MaxBytesReached { - messages = append(messages, fmt.Sprintf("Limit %d bytes reached.", readMaxBytes)) - } - if r.EndOfFile { - if !r.MaxLinesReached && !r.MaxBytesReached { - messages = append(messages, "End of file.") - } - } else if r.EndLine >= r.StartLine { - nextOffset := r.EndLine - if r.EndLine < math.MaxInt { - nextOffset = r.EndLine + 1 - } - if r.TotalLinesAvailable > 0 { - messages = append(messages, fmt.Sprintf("Total %d lines. Continue with line_offset=%d.", - r.TotalLinesAvailable, nextOffset)) - } else { - // Unknown total but not EOF - suggest continue anyway. - messages = append(messages, fmt.Sprintf("Continue with line_offset=%d if more content exists.", nextOffset)) - } - } - - if len(r.TruncatedLineNumbers) > 0 { - messages = append(messages, fmt.Sprintf("Truncated: %s.", formatTruncatedLines(r.TruncatedLineNumbers))) - } - - // Write status messages on separate lines for readability. - if len(messages) > 0 { - buf.WriteString("\n") - for _, msg := range messages { - buf.WriteString(msg) - buf.WriteString("\n") - } - } - - return buf.String() -} - -// ReadFileSimple reads an entire file without pagination (for backward compatibility/internal use). -// Suitable for small files only; applies pruning. -func ReadFileSimple(ctx context.Context, runner ExecRunner, botID, workDir, filePath string) (string, error) { - result, err := runner.ExecWithCapture(ctx, mcpgw.ExecRequest{ - BotID: botID, - Command: []string{"/bin/sh", "-c", wrapWithCd(workDir, "cat "+ShellQuote(filePath))}, - WorkDir: workDir, - }) - if err != nil { - return "", err - } - if result.ExitCode != 0 { - return "", fmt.Errorf("%s", strings.TrimSpace(result.Stderr)) - } - return pruneReadOutput(result.Stdout), nil -} diff --git a/internal/mcp/providers/container/read_test.go b/internal/mcp/providers/container/read_test.go deleted file mode 100644 index ac031777..00000000 --- a/internal/mcp/providers/container/read_test.go +++ /dev/null @@ -1,169 +0,0 @@ -package container - -import ( - "context" - "strings" - "testing" - - mcpgw "github.com/memohai/memoh/internal/mcp" -) - -type scriptedReadRunner struct { - handler func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) - calls []mcpgw.ExecRequest -} - -func (r *scriptedReadRunner) ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - r.calls = append(r.calls, req) - return r.handler(req) -} - -func TestParseReadOutput_LongSingleLineIsTruncated(t *testing.T) { - // 2MB single line without '\n' should still be readable in one page and truncated by rune limit. - longLine := strings.Repeat("a", 2*1024*1024) - - result := parseReadOutput(longLine, 1, readMaxLines, 1) - - if result.LinesRead != 1 { - t.Fatalf("LinesRead = %d, want 1", result.LinesRead) - } - if result.EndLine != 1 { - t.Fatalf("EndLine = %d, want 1", result.EndLine) - } - if !result.EndOfFile { - t.Fatalf("EndOfFile = false, want true") - } - if result.MaxBytesReached { - t.Fatalf("MaxBytesReached = true, want false") - } - if len(result.TruncatedLineNumbers) != 1 || result.TruncatedLineNumbers[0] != 1 { - t.Fatalf("TruncatedLineNumbers = %v, want [1]", result.TruncatedLineNumbers) - } - if !strings.Contains(result.Content, "\t"+strings.Repeat("a", readMaxLineLength)+"...\n") { - t.Fatalf("content does not contain expected truncated output, got: %q", result.Content) - } -} - -func TestParseReadOutput_TruncationMarkerForNearThresholdLine(t *testing.T) { - // 1001 ASCII chars: truncation happens, but output becomes 1003 chars due to "...". - // This verifies truncation tracking doesn't rely on byte-length shrinkage. - line := strings.Repeat("x", readMaxLineLength+1) - - result := parseReadOutput(line, 1, readMaxLines, 1) - - if len(result.TruncatedLineNumbers) != 1 || result.TruncatedLineNumbers[0] != 1 { - t.Fatalf("TruncatedLineNumbers = %v, want [1]", result.TruncatedLineNumbers) - } - - formatted := FormatReadResult(result) - if !strings.Contains(formatted, "Truncated: 1.") { - t.Fatalf("formatted output missing truncation marker, got: %q", formatted) - } -} - -func TestParseReadOutput_EmptyContentWithoutTotalMarksEOF(t *testing.T) { - result := parseReadOutput("", 401, readMaxLines, -1) - - if !result.EndOfFile { - t.Fatalf("EndOfFile = false, want true") - } - if result.LinesRead != 0 { - t.Fatalf("LinesRead = %d, want 0", result.LinesRead) - } - - formatted := FormatReadResult(result) - if strings.Contains(formatted, "Continue with line_offset=") { - t.Fatalf("formatted output should not contain continuation hint, got: %q", formatted) - } -} - -func TestReadFile_DoesNotScanWholeFileForTotalLines(t *testing.T) { - runner := &scriptedReadRunner{} - runner.handler = func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - cmd := strings.Join(req.Command, " ") - switch { - case strings.Contains(cmd, "head -c 8192"): - return &mcpgw.ExecWithCaptureResult{Stdout: "line\n", ExitCode: 0}, nil - case strings.Contains(cmd, "sed -n"): - return &mcpgw.ExecWithCaptureResult{Stdout: strings.Repeat("line\n", readMaxLines), ExitCode: 0}, nil - default: - t.Fatalf("unexpected command: %q", cmd) - return nil, nil - } - } - - result, err := ReadFile(context.Background(), runner, "bot-1", "/data", "test.txt", 201, 200) - if err != nil { - t.Fatal(err) - } - - if result.TotalLinesAvailable != -1 { - t.Fatalf("TotalLinesAvailable = %d, want -1", result.TotalLinesAvailable) - } - if result.EndOfFile { - t.Fatalf("EndOfFile = true, want false") - } - if result.LinesRead != 200 { - t.Fatalf("LinesRead = %d, want 200", result.LinesRead) - } - - for _, req := range runner.calls { - cmd := strings.Join(req.Command, " ") - if strings.Contains(cmd, "awk 'END {print NR}'") || strings.Contains(cmd, "wc -l") { - t.Fatalf("unexpected full-file line-count command: %q", cmd) - } - } - if len(runner.calls) != 2 { - t.Fatalf("expected exactly 2 commands to be executed, got %d", len(runner.calls)) - } - - formatted := FormatReadResult(result) - if !strings.Contains(formatted, "Continue with line_offset=401 if more content exists.") { - t.Fatalf("formatted output missing continuation hint, got: %q", formatted) - } -} - -func TestReadFile_BinaryContentReturnsError(t *testing.T) { - runner := &scriptedReadRunner{} - runner.handler = func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { - cmd := strings.Join(req.Command, " ") - if strings.Contains(cmd, "head -c 8192") { - return &mcpgw.ExecWithCaptureResult{ - Stdout: string([]byte{'a', 0, 'b'}), - ExitCode: 0, - }, nil - } - t.Fatalf("unexpected command: %q", cmd) - return nil, nil - } - - result, err := ReadFile(context.Background(), runner, "bot-1", "/data", "test.txt", 1, 10) - if err == nil { - t.Fatalf("expected binary-file error, got nil result=%v", result) - } - if !strings.Contains(err.Error(), "Read tool only supports text files") { - t.Fatalf("error = %q, want binary-file message", err.Error()) - } - if len(runner.calls) != 1 { - t.Fatalf("expected binary detection to stop before sed, got %d calls", len(runner.calls)) - } -} - -func TestFormatReadResult_ContinuationHintWhenMaxLinesReached(t *testing.T) { - content := strings.Repeat("line\n", readMaxLines) - result := parseReadOutput(content, 1, readMaxLines, -1) - if !result.MaxLinesReached { - t.Fatalf("MaxLinesReached = false, want true") - } - if result.EndOfFile { - t.Fatalf("EndOfFile = true, want false") - } - - formatted := FormatReadResult(result) - if !strings.Contains(formatted, "Limit 200 lines reached.\nContinue with line_offset=201 if more content exists.") { - t.Fatalf("formatted output missing continuation after limit, got: %q", formatted) - } - if strings.Contains(formatted, "Limit 200 lines reached. Continue with line_offset=201 if more content exists.") { - t.Fatalf("status messages should be on separate lines, got: %q", formatted) - } -} diff --git a/internal/mcp/versioning.go b/internal/mcp/versioning.go index c1b04eed..f5e83b5c 100644 --- a/internal/mcp/versioning.go +++ b/internal/mcp/versioning.go @@ -364,37 +364,32 @@ func (m *Manager) replaceContainerSnapshot(ctx context.Context, botID, container }); err != nil { return err } - if err := m.service.StartContainer(ctx, containerID, &ctr.StartTaskOptions{UseStdio: false}); err != nil { + if err := m.service.StartContainer(ctx, containerID, nil); err != nil { return err } - if err := m.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ + // Container process was recreated — evict the stale gRPC connection + // unconditionally so the next call dials fresh to the new process. + m.grpcPool.Remove(botID) + + if netResult, err := m.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{ ContainerID: containerID, CNIBinDir: m.cfg.CNIBinaryDir, CNIConfDir: m.cfg.CNIConfigDir, }); err != nil { m.logger.Warn("network setup failed after snapshot replace", slog.String("container_id", containerID), slog.Any("error", err)) + } else { + m.SetContainerIP(botID, netResult.IP) } return nil } -func (m *Manager) buildVersionSpec(botID string) (ctr.ContainerSpec, error) { - dataDir, err := m.ensureBotDir(botID) - if err != nil { - return ctr.ContainerSpec{}, err - } - dataMount := config.DefaultDataMount - resolvPath, err := ctr.ResolveConfSource(dataDir) +func (m *Manager) buildVersionSpec(_ string) (ctr.ContainerSpec, error) { + resolvPath, err := ctr.ResolveConfSource(m.dataRoot()) if err != nil { return ctr.ContainerSpec{}, err } mounts := []ctr.MountSpec{ - { - Destination: dataMount, - Type: "bind", - Source: dataDir, - Options: []string{"rbind", "rw"}, - }, { Destination: "/etc/resolv.conf", Type: "bind", @@ -425,10 +420,6 @@ func (m *Manager) safeStopTask(ctx context.Context, containerID string) error { } func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, runtime, imageRef string) (pgtype.UUID, error) { - hostPath, err := m.DataDir(botID) - if err != nil { - return pgtype.UUID{}, err - } botUUID, err := db.ParseUUID(botID) if err != nil { return pgtype.UUID{}, err @@ -446,7 +437,6 @@ func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, runti Status: "created", Namespace: "default", AutoStart: true, - HostPath: pgtype.Text{String: hostPath, Valid: hostPath != ""}, ContainerPath: containerPath, LastStartedAt: pgtype.Timestamptz{}, LastStoppedAt: pgtype.Timestamptz{}, diff --git a/internal/media/service.go b/internal/media/service.go index 147cc858..34f89511 100644 --- a/internal/media/service.go +++ b/internal/media/service.go @@ -155,7 +155,7 @@ func (s *Service) IngestContainerFile(ctx context.Context, botID, containerPath if !ok { return Asset{}, fmt.Errorf("provider does not support container file reading") } - f, err := opener.OpenContainerFile(botID, containerPath) + f, err := opener.OpenContainerFile(ctx, botID, containerPath) if err != nil { return Asset{}, fmt.Errorf("open container file: %w", err) } diff --git a/internal/memory/storefs/service.go b/internal/memory/storefs/service.go index 589fc196..e4c82db4 100644 --- a/internal/memory/storefs/service.go +++ b/internal/memory/storefs/service.go @@ -5,7 +5,8 @@ import ( "encoding/json" "errors" "fmt" - "net/http" + "io" + "maps" "path" "sort" "strconv" @@ -13,7 +14,7 @@ import ( "time" "github.com/memohai/memoh/internal/config" - fsops "github.com/memohai/memoh/internal/fs" + "github.com/memohai/memoh/internal/mcp/mcpclient" ) const manifestVersion = 1 @@ -43,10 +44,9 @@ type ManifestEntry struct { } type Service struct { - fs *fsops.Service + provider mcpclient.Provider } -// MemoryItem is the storefs-facing memory record type. type MemoryItem struct { ID string `json:"id"` Memory string `json:"memory"` @@ -60,12 +60,52 @@ type MemoryItem struct { RunID string `json:"run_id,omitempty"` } -func New(fs *fsops.Service) *Service { - return &Service{fs: fs} +func New(provider mcpclient.Provider) *Service { + return &Service{provider: provider} +} + +func (s *Service) client(ctx context.Context, botID string) (*mcpclient.Client, error) { + if s.provider == nil { + return nil, ErrNotConfigured + } + return s.provider.MCPClient(ctx, botID) +} + +func (s *Service) readFile(ctx context.Context, botID, filePath string) (string, error) { + c, err := s.client(ctx, botID) + if err != nil { + return "", err + } + reader, err := c.ReadRaw(ctx, filePath) + if err != nil { + return "", err + } + defer reader.Close() + data, err := io.ReadAll(reader) + if err != nil { + return "", err + } + return string(data), nil +} + +func (s *Service) writeFile(ctx context.Context, botID, filePath, content string) error { + c, err := s.client(ctx, botID) + if err != nil { + return err + } + return c.WriteFile(ctx, filePath, []byte(content)) +} + +func (s *Service) deleteFile(ctx context.Context, botID, filePath string, recursive bool) error { + c, err := s.client(ctx, botID) + if err != nil { + return err + } + return c.DeleteFile(ctx, filePath, recursive) } func (s *Service) PersistMemories(ctx context.Context, botID string, items []MemoryItem, filters map[string]any) error { - if s.fs == nil { + if s.provider == nil { return ErrNotConfigured } if len(items) == 0 { @@ -112,10 +152,8 @@ func (s *Service) PersistMemories(ctx context.Context, botID string, items []Mem return readErr } merged := toItemMap(existing) - for id, item := range incoming { - merged[id] = item - } - if err := s.writeMemoryDay(botID, filePath, mapToItems(merged)); err != nil { + maps.Copy(merged, incoming) + if err := s.writeMemoryDay(ctx, botID, filePath, mapToItems(merged)); err != nil { return err } } @@ -129,14 +167,11 @@ func (s *Service) PersistMemories(ctx context.Context, botID string, items []Mem } func (s *Service) RebuildFiles(ctx context.Context, botID string, items []MemoryItem, filters map[string]any) error { - if s.fs == nil { + if s.provider == nil { return ErrNotConfigured } - delErr := s.fs.Delete(botID, memoryDirPath(), true) - if delErr != nil { - if fsErr, ok := fsops.AsError(delErr); !ok || fsErr.Code != http.StatusNotFound { - return delErr - } + if err := s.deleteFile(ctx, botID, memoryDirPath(), true); err != nil && !isNotFound(err) { + return err } manifest := &Manifest{ Version: manifestVersion, @@ -164,7 +199,7 @@ func (s *Service) RebuildFiles(ctx context.Context, botID string, items []Memory } } for filePath, dayItems := range grouped { - if err := s.writeMemoryDay(botID, filePath, dayItems); err != nil { + if err := s.writeMemoryDay(ctx, botID, filePath, dayItems); err != nil { return err } } @@ -175,7 +210,7 @@ func (s *Service) RebuildFiles(ctx context.Context, botID string, items []Memory } func (s *Service) RemoveMemories(ctx context.Context, botID string, ids []string) error { - if s.fs == nil { + if s.provider == nil { return ErrNotConfigured } if len(ids) == 0 { @@ -217,14 +252,11 @@ func (s *Service) RemoveMemories(ctx context.Context, botID string, ids []string } func (s *Service) RemoveAllMemories(ctx context.Context, botID string) error { - if s.fs == nil { + if s.provider == nil { return ErrNotConfigured } - delErr := s.fs.Delete(botID, memoryDirPath(), true) - if delErr != nil { - if fsErr, ok := fsops.AsError(delErr); !ok || fsErr.Code != http.StatusNotFound { - return delErr - } + if err := s.deleteFile(ctx, botID, memoryDirPath(), true); err != nil && !isNotFound(err) { + return err } if err := s.writeManifest(ctx, botID, &Manifest{ Version: manifestVersion, @@ -237,29 +269,34 @@ func (s *Service) RemoveAllMemories(ctx context.Context, botID string) error { } func (s *Service) ReadAllMemoryFiles(ctx context.Context, botID string) ([]MemoryItem, error) { - if s.fs == nil { + if s.provider == nil { return nil, ErrNotConfigured } - list, err := s.fs.List(ctx, botID, memoryDirPath()) + c, err := s.client(ctx, botID) if err != nil { - if fsErr, ok := fsops.AsError(err); ok && fsErr.Code == http.StatusNotFound { + return nil, err + } + entries, err := c.ListDir(ctx, memoryDirPath(), false) + if err != nil { + if isNotFound(err) { return []MemoryItem{}, nil } return nil, err } - items := make([]MemoryItem, 0, len(list.Entries)) + items := make([]MemoryItem, 0, len(entries)) seen := map[string]struct{}{} - for _, entry := range list.Entries { - if entry.IsDir || !strings.HasSuffix(entry.Path, ".md") { + for _, entry := range entries { + if entry.GetIsDir() || !strings.HasSuffix(entry.GetPath(), ".md") { continue } - content, readErr := s.fs.ReadRaw(ctx, botID, entry.Path) + entryPath := path.Join(memoryDirPath(), entry.GetPath()) + content, readErr := s.readFile(ctx, botID, entryPath) if readErr != nil { continue } - parsed, parseErr := parseMemoryDayMD(content.Content) + parsed, parseErr := parseMemoryDayMD(content) if parseErr != nil { - legacy, legacyErr := parseLegacyMemoryMD(content.Content) + legacy, legacyErr := parseLegacyMemoryMD(content) if legacyErr != nil { continue } @@ -282,26 +319,24 @@ func (s *Service) ReadAllMemoryFiles(ctx context.Context, botID string) ([]Memor return items, nil } -// SyncOverview rebuilds /data/MEMORY.md from memory day files. func (s *Service) SyncOverview(ctx context.Context, botID string) error { - if s.fs == nil { + if s.provider == nil { return ErrNotConfigured } items, err := s.ReadAllMemoryFiles(ctx, botID) if err != nil { return err } - overview := formatMemoryOverviewMD(items) - return s.fs.Write(botID, memoryOverviewPath(), overview) + return s.writeFile(ctx, botID, memoryOverviewPath(), formatMemoryOverviewMD(items)) } func (s *Service) ReadManifest(ctx context.Context, botID string) (*Manifest, error) { - if s.fs == nil { + if s.provider == nil { return nil, ErrNotConfigured } - resp, err := s.fs.ReadRaw(ctx, botID, memoryManifestPath()) + content, err := s.readFile(ctx, botID, memoryManifestPath()) if err != nil { - if fsErr, ok := fsops.AsError(err); ok && fsErr.Code == http.StatusNotFound { + if isNotFound(err) { return &Manifest{ Version: manifestVersion, Entries: map[string]ManifestEntry{}, @@ -310,7 +345,7 @@ func (s *Service) ReadManifest(ctx context.Context, botID string) (*Manifest, er return nil, err } var manifest Manifest - if err := json.Unmarshal([]byte(resp.Content), &manifest); err != nil { + if err := json.Unmarshal([]byte(content), &manifest); err != nil { return nil, fmt.Errorf("parse manifest: %w", err) } if manifest.Entries == nil { @@ -332,12 +367,9 @@ func (s *Service) ReadManifest(ctx context.Context, botID string) (*Manifest, er return &manifest, nil } -func (s *Service) writeManifest(_ context.Context, botID string, manifest *Manifest) error { +func (s *Service) writeManifest(ctx context.Context, botID string, manifest *Manifest) error { if manifest == nil { - manifest = &Manifest{ - Version: manifestVersion, - Entries: map[string]ManifestEntry{}, - } + manifest = &Manifest{Version: manifestVersion, Entries: map[string]ManifestEntry{}} } if manifest.Entries == nil { manifest.Entries = map[string]ManifestEntry{} @@ -350,29 +382,79 @@ func (s *Service) writeManifest(_ context.Context, botID string, manifest *Manif if err != nil { return fmt.Errorf("marshal manifest: %w", err) } - return s.fs.Write(botID, memoryManifestPath(), string(data)) + return s.writeFile(ctx, botID, memoryManifestPath(), string(data)) } -func memoryManifestPath() string { - return path.Join(config.DefaultDataMount, "index", "manifest.json") +func (s *Service) readMemoryDay(ctx context.Context, botID, filePath string) ([]MemoryItem, error) { + content, err := s.readFile(ctx, botID, filePath) + if err != nil { + if isNotFound(err) { + return []MemoryItem{}, nil + } + return nil, err + } + items, parseErr := parseMemoryDayMD(content) + if parseErr == nil { + return items, nil + } + legacy, legacyErr := parseLegacyMemoryMD(content) + if legacyErr != nil { + return []MemoryItem{}, nil + } + return []MemoryItem{legacy}, nil } -func memoryOverviewPath() string { - return path.Join(config.DefaultDataMount, "MEMORY.md") +func (s *Service) writeMemoryDay(ctx context.Context, botID, filePath string, items []MemoryItem) error { + date := strings.TrimSuffix(path.Base(filePath), ".md") + return s.writeFile(ctx, botID, filePath, formatMemoryDayMD(date, items)) } -func memoryDirPath() string { - return path.Join(config.DefaultDataMount, "memory") +func (s *Service) removeIDsFromFiles(ctx context.Context, botID string, removals map[string]map[string]struct{}) error { + for filePath, ids := range removals { + if len(ids) == 0 { + continue + } + items, err := s.readMemoryDay(ctx, botID, filePath) + if err != nil { + return err + } + if len(items) == 0 { + continue + } + filtered := make([]MemoryItem, 0, len(items)) + for _, item := range items { + if _, remove := ids[item.ID]; remove { + continue + } + filtered = append(filtered, item) + } + if len(filtered) == 0 { + if err := s.deleteFile(ctx, botID, filePath, false); err != nil && !isNotFound(err) { + return err + } + continue + } + if err := s.writeMemoryDay(ctx, botID, filePath, filtered); err != nil { + return err + } + } + return nil } +// --- path helpers --- + +func memoryManifestPath() string { return path.Join(config.DefaultDataMount, "index", "manifest.json") } +func memoryOverviewPath() string { return path.Join(config.DefaultDataMount, "MEMORY.md") } +func memoryDirPath() string { return path.Join(config.DefaultDataMount, "memory") } func memoryDayPath(date string) string { return path.Join(memoryDirPath(), strings.TrimSpace(date)+".md") } - func memoryLegacyItemPath(id string) string { return path.Join(memoryDirPath(), strings.TrimSpace(id)+".md") } +// --- format / parse helpers --- + func formatMemoryDayMD(date string, items []MemoryItem) string { var b strings.Builder b.WriteString("# Memory ") @@ -391,9 +473,7 @@ func formatMemoryDayMD(date string, items []MemoryItem) string { if item.ID == "" || item.Memory == "" { continue } - meta := map[string]string{ - "id": item.ID, - } + meta := map[string]string{"id": item.ID} if item.Hash != "" { meta["hash"] = item.Hash } @@ -470,11 +550,8 @@ func parseLegacyMemoryMD(content string) (MemoryItem, error) { if len(parts) < 2 { return MemoryItem{}, fmt.Errorf("incomplete frontmatter") } - frontmatter := strings.TrimSpace(parts[0]) - body := strings.TrimSpace(parts[1]) - - item := MemoryItem{Memory: body} - for _, line := range strings.Split(frontmatter, "\n") { + item := MemoryItem{Memory: strings.TrimSpace(parts[1])} + for _, line := range strings.Split(strings.TrimSpace(parts[0]), "\n") { key, value, found := strings.Cut(strings.TrimSpace(line), ":") if !found { continue @@ -496,63 +573,57 @@ func parseLegacyMemoryMD(content string) (MemoryItem, error) { return item, nil } -func (s *Service) readMemoryDay(ctx context.Context, botID, filePath string) ([]MemoryItem, error) { - resp, err := s.fs.ReadRaw(ctx, botID, filePath) - if err != nil { - if fsErr, ok := fsops.AsError(err); ok && fsErr.Code == http.StatusNotFound { - return []MemoryItem{}, nil +func formatMemoryOverviewMD(items []MemoryItem) string { + var b strings.Builder + b.WriteString("# MEMORY\n\n") + if len(items) == 0 { + b.WriteString("> No memory entries yet.\n") + return b.String() + } + ordered := append([]MemoryItem(nil), items...) + sort.Slice(ordered, func(i, j int) bool { + ti, tj := memoryTime(ordered[i]), memoryTime(ordered[j]) + if ti.Equal(tj) { + return ordered[i].ID > ordered[j].ID } - return nil, err + return ti.After(tj) + }) + for i, item := range ordered { + if i >= 500 { + break + } + id := strings.TrimSpace(item.ID) + if id == "" { + id = "unknown" + } + created := strings.TrimSpace(item.CreatedAt) + if created == "" { + created = "unknown" + } + body := strings.TrimSpace(item.Memory) + if body == "" { + continue + } + body = strings.Join(strings.Fields(body), " ") + if len(body) > 400 { + body = strings.TrimSpace(body[:400]) + "..." + } + b.WriteString(strconv.Itoa(i + 1)) + b.WriteString(". [") + b.WriteString(created) + b.WriteString("] (") + b.WriteString(id) + b.WriteString(") ") + b.WriteString(body) + b.WriteString("\n") } - items, parseErr := parseMemoryDayMD(resp.Content) - if parseErr == nil { - return items, nil - } - legacy, legacyErr := parseLegacyMemoryMD(resp.Content) - if legacyErr != nil { - return []MemoryItem{}, nil - } - return []MemoryItem{legacy}, nil + return b.String() } -func (s *Service) writeMemoryDay(botID, filePath string, items []MemoryItem) error { - date := strings.TrimSuffix(path.Base(filePath), ".md") - return s.fs.Write(botID, filePath, formatMemoryDayMD(date, items)) -} +// --- utility helpers --- -func (s *Service) removeIDsFromFiles(ctx context.Context, botID string, removals map[string]map[string]struct{}) error { - for filePath, ids := range removals { - if len(ids) == 0 { - continue - } - items, err := s.readMemoryDay(ctx, botID, filePath) - if err != nil { - return err - } - if len(items) == 0 { - continue - } - filtered := make([]MemoryItem, 0, len(items)) - for _, item := range items { - if _, remove := ids[item.ID]; remove { - continue - } - filtered = append(filtered, item) - } - if len(filtered) == 0 { - delErr := s.fs.Delete(botID, filePath, false) - if delErr != nil { - if fsErr, ok := fsops.AsError(delErr); !ok || fsErr.Code != http.StatusNotFound { - return delErr - } - } - continue - } - if err := s.writeMemoryDay(botID, filePath, filtered); err != nil { - return err - } - } - return nil +func isNotFound(err error) bool { + return errors.Is(err, mcpclient.ErrNotFound) } func toItemMap(items []MemoryItem) map[string]MemoryItem { @@ -599,20 +670,13 @@ func memoryDateFromRaw(raw string, now time.Time) string { if raw == "" { return now.Format(memoryDateLayout) } - layouts := []string{ - time.RFC3339Nano, - time.RFC3339, - "2006-01-02 15:04:05", - memoryDateLayout, - } - for _, layout := range layouts { + for _, layout := range []string{time.RFC3339Nano, time.RFC3339, "2006-01-02 15:04:05", memoryDateLayout} { if t, err := time.Parse(layout, raw); err == nil { return t.UTC().Format(memoryDateLayout) } } if len(raw) >= len(memoryDateLayout) { - candidate := raw[:len(memoryDateLayout)] - if t, err := time.Parse(memoryDateLayout, candidate); err == nil { + if t, err := time.Parse(memoryDateLayout, raw[:len(memoryDateLayout)]); err == nil { return t.UTC().Format(memoryDateLayout) } } @@ -641,55 +705,3 @@ func memoryTime(item MemoryItem) time.Time { } return time.Time{} } - -func formatMemoryOverviewMD(items []MemoryItem) string { - var b strings.Builder - b.WriteString("# MEMORY\n\n") - if len(items) == 0 { - b.WriteString("> No memory entries yet.\n") - return b.String() - } - ordered := append([]MemoryItem(nil), items...) - sort.Slice(ordered, func(i, j int) bool { - ti, tj := memoryTime(ordered[i]), memoryTime(ordered[j]) - if ti.Equal(tj) { - return ordered[i].ID > ordered[j].ID - } - return ti.After(tj) - }) - for i, item := range ordered { - if i >= 500 { - break - } - id := strings.TrimSpace(item.ID) - if id == "" { - id = "unknown" - } - created := strings.TrimSpace(item.CreatedAt) - if created == "" { - created = "unknown" - } - body := strings.TrimSpace(item.Memory) - if body == "" { - continue - } - lines := strings.Split(body, "\n") - for idx, line := range lines { - lines[idx] = strings.TrimSpace(line) - } - body = strings.Join(lines, " ") - body = strings.Join(strings.Fields(body), " ") - if len(body) > 400 { - body = strings.TrimSpace(body[:400]) + "..." - } - b.WriteString(strconv.Itoa(i + 1)) - b.WriteString(". [") - b.WriteString(created) - b.WriteString("] (") - b.WriteString(id) - b.WriteString(") ") - b.WriteString(body) - b.WriteString("\n") - } - return b.String() -} diff --git a/internal/memory/storefs/service_test.go b/internal/memory/storefs/service_test.go index 5262b93f..9ed4c998 100644 --- a/internal/memory/storefs/service_test.go +++ b/internal/memory/storefs/service_test.go @@ -62,4 +62,3 @@ legacy content` t.Fatalf("unexpected memory body: %#v", item) } } - diff --git a/internal/storage/providers/containerfs/provider.go b/internal/storage/providers/containerfs/provider.go index df6756c1..0ab1deec 100644 --- a/internal/storage/providers/containerfs/provider.go +++ b/internal/storage/providers/containerfs/provider.go @@ -1,115 +1,84 @@ // Package containerfs implements storage.Provider for bot containers -// backed by host-side bind mounts. Writing to /bots//media/ -// on the host makes the file available at /data/media/ inside the container. +// backed by gRPC calls to the in-container MCP service. Files are stored +// inside the container's writable layer at /data/media/. package containerfs import ( "context" "fmt" "io" - "os" "path/filepath" "strings" + + "github.com/memohai/memoh/internal/mcp/mcpclient" ) -const containerMediaRoot = "/data/media" +const containerMediaRoot = "media" -// Provider stores media assets via the host-side bind mount path -// that maps to /data inside bot containers. +// Provider stores media assets inside bot containers via gRPC. type Provider struct { - dataRoot string + clients mcpclient.Provider } // New creates a container-based storage provider. -// dataRoot is the host directory that contains per-bot data (e.g. "data"). -func New(dataRoot string) (*Provider, error) { - abs, err := filepath.Abs(dataRoot) - if err != nil { - return nil, fmt.Errorf("resolve data root: %w", err) - } - return &Provider{dataRoot: abs}, nil +func New(clients mcpclient.Provider) *Provider { + return &Provider{clients: clients} } -// Put writes data to the host bind mount path for the bot container. -func (p *Provider) Put(_ context.Context, key string, reader io.Reader) error { - dest, err := p.hostPath(key) +// Put writes data to the bot container via gRPC streaming. +func (p *Provider) Put(ctx context.Context, key string, reader io.Reader) error { + botID, sub, err := parseRoutingKey(key) if err != nil { return err } - if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { - return fmt.Errorf("create parent dir: %w", err) - } - f, err := os.Create(dest) + client, err := p.clients.MCPClient(ctx, botID) if err != nil { - return fmt.Errorf("create file: %w", err) + return fmt.Errorf("get client: %w", err) } - defer f.Close() - if _, err := io.Copy(f, reader); err != nil { + containerPath := filepath.Join(containerMediaRoot, sub) + if _, err := client.WriteRaw(ctx, containerPath, reader); err != nil { return fmt.Errorf("write file: %w", err) } return nil } -// Open reads a file from the host bind mount path. -func (p *Provider) Open(_ context.Context, key string) (io.ReadCloser, error) { - dest, err := p.hostPath(key) +// Open reads a file from the bot container via gRPC streaming. +func (p *Provider) Open(ctx context.Context, key string) (io.ReadCloser, error) { + botID, sub, err := parseRoutingKey(key) if err != nil { return nil, err } - f, err := os.Open(dest) + client, err := p.clients.MCPClient(ctx, botID) if err != nil { - return nil, fmt.Errorf("open file: %w", err) + return nil, fmt.Errorf("get client: %w", err) } - return f, nil + containerPath := filepath.Join(containerMediaRoot, sub) + return client.ReadRaw(ctx, containerPath) } -// Delete removes a file from the host bind mount path. -func (p *Provider) Delete(_ context.Context, key string) error { - dest, err := p.hostPath(key) +// Delete removes a file from the bot container. +func (p *Provider) Delete(ctx context.Context, key string) error { + botID, sub, err := parseRoutingKey(key) if err != nil { return err } - if err := os.Remove(dest); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("delete file: %w", err) + client, err := p.clients.MCPClient(ctx, botID) + if err != nil { + return fmt.Errorf("get client: %w", err) } - return nil + containerPath := filepath.Join(containerMediaRoot, sub) + return client.DeleteFile(ctx, containerPath, false) } // AccessPath returns the container-internal path for a storage key. -// Routing key format: "/" → "/data/media/". func (p *Provider) AccessPath(key string) string { _, sub := splitRoutingKey(key) - return filepath.Join("/data", "media", sub) + return filepath.Join("/data", containerMediaRoot, sub) } -// hostPath converts a routing key into the host-side file path. -// Routing key format: "/" → "/bots//media/". -func (p *Provider) hostPath(key string) (string, error) { - clean := filepath.Clean(key) - if filepath.IsAbs(clean) { - return "", fmt.Errorf("absolute key is forbidden: %s", key) - } - if strings.HasPrefix(clean, ".."+string(filepath.Separator)) || clean == ".." { - return "", fmt.Errorf("path traversal is forbidden: %s", key) - } - botID, subPath := splitRoutingKey(clean) - if strings.TrimSpace(botID) == "" || strings.TrimSpace(subPath) == "" { - return "", fmt.Errorf("invalid storage key: %s", key) - } - joined := filepath.Join(p.dataRoot, "bots", botID, "media", subPath) - if !strings.HasPrefix(joined, p.dataRoot+string(filepath.Separator)) { - return "", fmt.Errorf("path escapes data root: %s", key) - } - return joined, nil -} - -// OpenContainerFile opens a file from a bot's /data/ directory on the host. -// containerPath must start with the data mount path. -func (p *Provider) OpenContainerFile(botID, containerPath string) (io.ReadCloser, error) { - dataPrefix := "/data" - if !strings.HasSuffix(dataPrefix, "/") { - dataPrefix += "/" - } +// OpenContainerFile opens a file from a bot's /data/ directory. +func (p *Provider) OpenContainerFile(ctx context.Context, botID, containerPath string) (io.ReadCloser, error) { + dataPrefix := "/data/" if !strings.HasPrefix(containerPath, dataPrefix) { return nil, fmt.Errorf("path must start with %s", dataPrefix) } @@ -117,32 +86,35 @@ func (p *Provider) OpenContainerFile(botID, containerPath string) (io.ReadCloser if subPath == "" || strings.Contains(subPath, "..") { return nil, fmt.Errorf("invalid container path") } - hostPath := filepath.Join(p.dataRoot, "bots", botID, subPath) - if !strings.HasPrefix(hostPath, p.dataRoot+string(filepath.Separator)) { - return nil, fmt.Errorf("path escapes data root") + client, err := p.clients.MCPClient(ctx, botID) + if err != nil { + return nil, fmt.Errorf("get client: %w", err) } - return os.Open(hostPath) + return client.ReadRaw(ctx, subPath) } // ListPrefix returns all keys under the given routing prefix. -// prefix is expected to be of the form "//" (without extension). -func (p *Provider) ListPrefix(_ context.Context, prefix string) ([]string, error) { +func (p *Provider) ListPrefix(ctx context.Context, prefix string) ([]string, error) { botID, sub := splitRoutingKey(prefix) if botID == "" || sub == "" { return nil, nil } - dir := filepath.Dir(filepath.Join(p.dataRoot, "bots", botID, "media", sub)) + client, err := p.clients.MCPClient(ctx, botID) + if err != nil { + return nil, nil + } + dir := filepath.Dir(filepath.Join(containerMediaRoot, sub)) base := filepath.Base(sub) - entries, err := os.ReadDir(dir) + entries, err := client.ListDir(ctx, dir, false) if err != nil { return nil, nil } var keys []string for _, e := range entries { - if e.IsDir() { + if e.GetIsDir() { continue } - name := e.Name() + name := e.GetPath() if strings.HasPrefix(name, base) { storageKey := filepath.Join(filepath.Dir(sub), name) keys = append(keys, filepath.Join(botID, storageKey)) @@ -151,7 +123,21 @@ func (p *Provider) ListPrefix(_ context.Context, prefix string) ([]string, error return keys, nil } -// splitRoutingKey splits a routing key "/" into its parts. +func parseRoutingKey(key string) (botID, storageKey string, err error) { + clean := filepath.Clean(key) + if filepath.IsAbs(clean) { + return "", "", fmt.Errorf("absolute key is forbidden: %s", key) + } + if strings.HasPrefix(clean, ".."+string(filepath.Separator)) || clean == ".." { + return "", "", fmt.Errorf("path traversal is forbidden: %s", key) + } + botID, sub := splitRoutingKey(clean) + if strings.TrimSpace(botID) == "" || strings.TrimSpace(sub) == "" { + return "", "", fmt.Errorf("invalid storage key: %s", key) + } + return botID, sub, nil +} + func splitRoutingKey(key string) (botID, storageKey string) { idx := strings.IndexByte(key, filepath.Separator) if idx <= 0 { diff --git a/internal/storage/providers/containerfs/provider_test.go b/internal/storage/providers/containerfs/provider_test.go index 346ded1c..82c2b47e 100644 --- a/internal/storage/providers/containerfs/provider_test.go +++ b/internal/storage/providers/containerfs/provider_test.go @@ -1,50 +1,34 @@ package containerfs -import ( - "bytes" - "context" - "io" - "os" - "path/filepath" - "testing" -) +import "testing" -func TestProvider_HostPath(t *testing.T) { +func TestParseRoutingKey(t *testing.T) { t.Parallel() - p := &Provider{dataRoot: "/srv/data"} tests := []struct { key string - want string wantErr bool }{ - {key: "bot-1/image/ab12/ab12cd.png", want: "/srv/data/bots/bot-1/media/image/ab12/ab12cd.png"}, + {key: "bot-1/image/ab12/ab12cd.png", wantErr: false}, {key: "/absolute/path", wantErr: true}, {key: "../escape", wantErr: true}, {key: "nosubpath", wantErr: true}, {key: "", wantErr: true}, } for _, tt := range tests { - got, err := p.hostPath(tt.key) - if tt.wantErr { - if err == nil { - t.Errorf("hostPath(%q) expected error", tt.key) - } - continue + _, _, err := parseRoutingKey(tt.key) + if tt.wantErr && err == nil { + t.Errorf("parseRoutingKey(%q) expected error", tt.key) } - if err != nil { - t.Errorf("hostPath(%q) unexpected error: %v", tt.key, err) - continue - } - if got != tt.want { - t.Errorf("hostPath(%q) = %q, want %q", tt.key, got, tt.want) + if !tt.wantErr && err != nil { + t.Errorf("parseRoutingKey(%q) unexpected error: %v", tt.key, err) } } } func TestProvider_AccessPath(t *testing.T) { t.Parallel() - p := &Provider{dataRoot: "/srv/data"} + p := &Provider{} tests := []struct { key string @@ -61,47 +45,8 @@ func TestProvider_AccessPath(t *testing.T) { } } -func TestProvider_PutOpenDelete(t *testing.T) { +func TestParseRoutingKey_PathTraversal(t *testing.T) { t.Parallel() - tmpDir := t.TempDir() - p, err := New(tmpDir) - if err != nil { - t.Fatalf("New failed: %v", err) - } - - key := "bot-1/image/ab/test.png" - data := []byte("hello media content") - - if err := p.Put(context.Background(), key, bytes.NewReader(data)); err != nil { - t.Fatalf("Put failed: %v", err) - } - - hostFile := filepath.Join(tmpDir, "bots", "bot-1", "media", "image", "ab", "test.png") - if _, err := os.Stat(hostFile); err != nil { - t.Fatalf("file not found on host: %v", err) - } - - reader, err := p.Open(context.Background(), key) - if err != nil { - t.Fatalf("Open failed: %v", err) - } - got, _ := io.ReadAll(reader) - reader.Close() - if !bytes.Equal(got, data) { - t.Errorf("Open returned %q, want %q", got, data) - } - - if err := p.Delete(context.Background(), key); err != nil { - t.Fatalf("Delete failed: %v", err) - } - if _, err := os.Stat(hostFile); !os.IsNotExist(err) { - t.Fatalf("file should be deleted: %v", err) - } -} - -func TestProvider_PathTraversal(t *testing.T) { - t.Parallel() - p := &Provider{dataRoot: "/srv/data"} bad := []string{ "../etc/passwd", @@ -109,8 +54,22 @@ func TestProvider_PathTraversal(t *testing.T) { "bot-1/../../escape", } for _, key := range bad { - if _, err := p.hostPath(key); err == nil { - t.Errorf("hostPath(%q) should reject traversal", key) + if _, _, err := parseRoutingKey(key); err == nil { + t.Errorf("parseRoutingKey(%q) should reject traversal", key) } } } + +func TestSplitRoutingKey(t *testing.T) { + t.Parallel() + + botID, sub := splitRoutingKey("bot-1/image/test.png") + if botID != "bot-1" || sub != "image/test.png" { + t.Errorf("splitRoutingKey: got (%q, %q)", botID, sub) + } + + botID2, sub2 := splitRoutingKey("nosubpath") + if botID2 != "" || sub2 != "nosubpath" { + t.Errorf("splitRoutingKey single: got (%q, %q)", botID2, sub2) + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index c83a98b2..d0cc3aa5 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -22,7 +22,7 @@ type Provider interface { // ContainerFileOpener is an optional interface that providers can implement // to open arbitrary files from a bot's container data directory. type ContainerFileOpener interface { - OpenContainerFile(botID, containerPath string) (io.ReadCloser, error) + OpenContainerFile(ctx context.Context, botID, containerPath string) (io.ReadCloser, error) } // PrefixLister is an optional interface for providers that can list keys