fix(mcp): fix snapshot management and encapsulate locking (#169)

- Fix DeleteContainer FAILED_PRECONDITION by cleaning up stopped task
  entries before container deletion
- Fix CreateSnapshot leaving container in broken state: commit turns
  the active snapshot read-only, so the full cycle (stop → commit →
  prepare → delete → recreate → start) is now applied consistently
- Use context.WithoutCancel for atomic container replacement sequences
  to prevent cancelled HTTP requests from corrupting container state
- Use the detached context (dctx) for DB operations
  (recordSnapshotVersion/insertEvent) to avoid orphan snapshots in
  containerd without matching DB records
- Restart task + network after snapshot replacement, fixing Exec after
  CreateVersion where the container had no running task
- Extract replaceContainerSnapshot helper to deduplicate the prepare →
  delete → recreate → start pattern across three call sites
- Move snapshot list data fetching into Manager.ListBotSnapshotData to
  encapsulate per-container locking; remove exported LockBot method
- Use UnixNano for snapshot names to avoid second-precision collisions
This commit is contained in:
BBQ
2026-03-03 15:59:57 +08:00
committed by GitHub
parent 78faee4a0e
commit ee587b8ef5
3 changed files with 180 additions and 115 deletions
+10
View File
@@ -407,6 +407,16 @@ func (s *DefaultService) DeleteContainer(ctx context.Context, id string, opts *D
return err
}
// A stopped task still holds an entry in containerd; container.Delete fails
// with FAILED_PRECONDITION if any task entry exists. Delete it first.
if task, err := container.Task(ctx, nil); err == nil {
if _, err := task.Delete(ctx, containerd.WithProcessKill); err != nil && !errdefs.IsNotFound(err) {
return err
}
} else if !errdefs.IsNotFound(err) {
return err
}
deleteOpts := []containerd.DeleteOpts{}
cleanupSnapshot := true
if opts != nil {
+26 -60
View File
@@ -583,12 +583,11 @@ func (h *ContainerdHandler) ListSnapshots(c echo.Context) error {
if err != nil {
return err
}
ctx := c.Request().Context()
containerID, err := h.botContainerID(ctx, botID)
if err != nil {
return echo.NewHTTPError(http.StatusNotFound, "container not found for bot")
if h.manager == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "snapshot manager not configured")
}
containerInfo, err := h.service.GetContainer(ctx, containerID)
data, err := h.manager.ListBotSnapshotData(c.Request().Context(), botID)
if err != nil {
if errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusNotFound, "container not found")
@@ -596,80 +595,48 @@ func (h *ContainerdHandler) ListSnapshots(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
requestedSnapshotter := strings.TrimSpace(c.QueryParam("snapshotter"))
snapshotter := strings.TrimSpace(containerInfo.Snapshotter)
if requestedSnapshotter != "" {
if snapshotter != "" && requestedSnapshotter != snapshotter {
if req := strings.TrimSpace(c.QueryParam("snapshotter")); req != "" && req != data.Snapshotter {
return echo.NewHTTPError(http.StatusBadRequest, "snapshotter does not match container snapshotter")
}
snapshotter = requestedSnapshotter
}
if snapshotter == "" {
snapshotter = strings.TrimSpace(h.cfg.Snapshotter)
}
if snapshotter == "" {
snapshotter = "overlayfs"
}
snapshotKey := strings.TrimSpace(containerInfo.SnapshotKey)
snapshotKey := strings.TrimSpace(data.Info.SnapshotKey)
if snapshotKey == "" {
return echo.NewHTTPError(http.StatusInternalServerError, "container snapshot key is empty")
}
allSnapshots, err := h.service.ListSnapshots(ctx, snapshotter)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
runtimeByName := make(map[string]ctr.SnapshotInfo, len(allSnapshots))
for _, info := range allSnapshots {
runtimeByName := make(map[string]ctr.SnapshotInfo, len(data.RuntimeSnapshots))
for _, info := range data.RuntimeSnapshots {
name := strings.TrimSpace(info.Name)
if name == "" {
continue
}
runtimeByName[name] = info
}
lineage, ok := snapshotLineage(snapshotKey, allSnapshots)
lineage, ok := snapshotLineage(snapshotKey, data.RuntimeSnapshots)
if !ok {
h.logger.Warn("container snapshot chain root not found",
slog.String("container_id", containerID),
slog.String("snapshotter", snapshotter),
slog.String("container_id", data.ContainerID),
slog.String("snapshotter", data.Snapshotter),
slog.String("snapshot_key", snapshotKey),
)
return echo.NewHTTPError(http.StatusInternalServerError, "container snapshot chain not found")
}
metadataByName := map[string]dbsqlc.ListSnapshotsWithVersionByContainerIDRow{}
if h.queries != nil {
managedRows, dbErr := h.queries.ListSnapshotsWithVersionByContainerID(ctx, containerID)
if dbErr != nil {
return echo.NewHTTPError(http.StatusInternalServerError, dbErr.Error())
}
for _, row := range managedRows {
name := strings.TrimSpace(row.RuntimeSnapshotName)
if name == "" {
continue
}
metadataByName[name] = row
}
}
items := make([]SnapshotInfo, 0, len(lineage)+len(metadataByName))
seen := make(map[string]struct{}, len(lineage)+len(metadataByName))
appendRuntime := func(runtimeInfo ctr.SnapshotInfo, fallbackSource string, meta *dbsqlc.ListSnapshotsWithVersionByContainerIDRow) {
items := make([]SnapshotInfo, 0, len(lineage)+len(data.ManagedMeta))
seen := make(map[string]struct{}, len(lineage)+len(data.ManagedMeta))
appendRuntime := func(runtimeInfo ctr.SnapshotInfo, fallbackSource string, meta *mcp.ManagedSnapshotMeta) {
source := fallbackSource
managed := false
var version *int
if meta != nil {
if strings.TrimSpace(meta.Source) != "" {
source = strings.TrimSpace(meta.Source)
if meta.Source != "" {
source = meta.Source
}
managed = true
if meta.Version.Valid {
v := int(meta.Version.Int32)
version = &v
}
version = meta.Version
}
items = append(items, SnapshotInfo{
Snapshotter: snapshotter,
Snapshotter: data.Snapshotter,
Name: runtimeInfo.Name,
Parent: runtimeInfo.Parent,
Kind: runtimeInfo.Kind,
@@ -685,28 +652,27 @@ func (h *ContainerdHandler) ListSnapshots(c echo.Context) error {
for _, runtimeInfo := range lineage {
name := strings.TrimSpace(runtimeInfo.Name)
row, hasMeta := metadataByName[name]
if hasMeta {
appendRuntime(runtimeInfo, "image_layer", &row)
if meta, hasMeta := data.ManagedMeta[name]; hasMeta {
appendRuntime(runtimeInfo, "image_layer", &meta)
continue
}
appendRuntime(runtimeInfo, "image_layer", nil)
}
for name, row := range metadataByName {
for name, meta := range data.ManagedMeta {
if _, exists := seen[name]; exists {
continue
}
runtimeInfo, exists := runtimeByName[name]
if !exists {
h.logger.Warn("managed snapshot not found in runtime",
slog.String("container_id", containerID),
slog.String("container_id", data.ContainerID),
slog.String("snapshot_name", name),
slog.String("snapshotter", snapshotter),
slog.String("snapshotter", data.Snapshotter),
)
continue
}
appendRuntime(runtimeInfo, "managed", &row)
appendRuntime(runtimeInfo, "managed", &meta)
}
sort.Slice(items, func(i, j int) bool {
if items[i].CreatedAt.Equal(items[j].CreatedAt) {
@@ -715,7 +681,7 @@ func (h *ContainerdHandler) ListSnapshots(c echo.Context) error {
return items[i].CreatedAt.Before(items[j].CreatedAt)
})
return c.JSON(http.StatusOK, ListSnapshotsResponse{
Snapshotter: snapshotter,
Snapshotter: data.Snapshotter,
Snapshots: items,
})
}
+143 -54
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
@@ -39,6 +40,19 @@ type SnapshotCreateInfo struct {
CreatedAt time.Time
}
// ManagedSnapshotMeta is the database-side metadata attached to a runtime
// snapshot, decoupled from the generated query row type.
type ManagedSnapshotMeta struct {
// Source is the whitespace-trimmed source recorded for the snapshot; it may
// be empty.
Source string
// Version is the version number linked to the snapshot, or nil when the
// database row carries no valid version.
Version *int
}
// BotSnapshotData bundles everything ListBotSnapshotData collects for one bot
// so that callers can assemble a snapshot listing from a single result.
type BotSnapshotData struct {
// ContainerID is the container ID derived from the bot ID.
ContainerID string
// Info is the container info returned by the container service.
Info ctr.ContainerInfo
// Snapshotter is the resolved snapshotter name: the container's own value,
// else the manager configuration, else "overlayfs".
Snapshotter string
// RuntimeSnapshots lists the snapshots the container service reports for
// Snapshotter.
RuntimeSnapshots []ctr.SnapshotInfo
// ManagedMeta maps a trimmed runtime snapshot name to its database metadata.
// It is non-nil but empty when no query layer is configured.
ManagedMeta map[string]ManagedSnapshotMeta
}
func (m *Manager) CreateSnapshot(ctx context.Context, botID, snapshotName, source string) (*SnapshotCreateInfo, error) {
if m.db == nil || m.queries == nil {
return nil, fmt.Errorf("db is not configured")
@@ -61,16 +75,30 @@ func (m *Manager) CreateSnapshot(ctx context.Context, botID, snapshotName, sourc
normalizedSnapshotName := strings.TrimSpace(snapshotName)
if normalizedSnapshotName == "" {
normalizedSnapshotName = fmt.Sprintf("%s-%s", containerID, time.Now().Format("20060102150405"))
normalizedSnapshotName = fmt.Sprintf("%s-%d", containerID, time.Now().UnixNano())
}
normalizedSource := normalizeSnapshotSource(source)
if err := m.service.CommitSnapshot(ctx, info.Snapshotter, normalizedSnapshotName, info.SnapshotKey); err != nil {
// The sequence below (stop → commit → replace → start) is atomic from the
// container's perspective: interrupting it mid-way leaves the container missing.
// Use a detached context so a cancelled HTTP request cannot break it.
dctx := context.WithoutCancel(ctx)
if err := m.safeStopTask(dctx, containerID); err != nil {
return nil, err
}
if err := m.service.CommitSnapshot(dctx, info.Snapshotter, normalizedSnapshotName, info.SnapshotKey); err != nil {
return nil, err
}
activeSnapshotName := fmt.Sprintf("%s-active-%d", containerID, time.Now().UnixNano())
if err := m.replaceContainerSnapshot(dctx, botID, containerID, info, activeSnapshotName, normalizedSnapshotName); err != nil {
return nil, err
}
_, versionNumber, createdAt, err := m.recordSnapshotVersion(
ctx,
dctx,
containerID,
normalizedSnapshotName,
info.SnapshotKey,
@@ -80,7 +108,7 @@ func (m *Manager) CreateSnapshot(ctx context.Context, botID, snapshotName, sourc
if err != nil {
return nil, err
}
if err := m.insertEvent(ctx, containerID, "snapshot_create", map[string]any{
if err := m.insertEvent(dctx, containerID, "snapshot_create", map[string]any{
"snapshot_name": normalizedSnapshotName,
"snapshotter": info.Snapshotter,
"source": normalizedSource,
@@ -119,43 +147,24 @@ func (m *Manager) CreateVersion(ctx context.Context, botID string) (*VersionInfo
return nil, err
}
if err := m.safeStopTask(ctx, containerID); err != nil {
dctx := context.WithoutCancel(ctx)
if err := m.safeStopTask(dctx, containerID); err != nil {
return nil, err
}
versionSnapshotName := fmt.Sprintf("%s-v%d", containerID, time.Now().UnixNano())
if err := m.service.CommitSnapshot(ctx, info.Snapshotter, versionSnapshotName, info.SnapshotKey); err != nil {
if err := m.service.CommitSnapshot(dctx, info.Snapshotter, versionSnapshotName, info.SnapshotKey); err != nil {
return nil, err
}
activeSnapshotName := fmt.Sprintf("%s-active-%d", containerID, time.Now().UnixNano())
if err := m.service.PrepareSnapshot(ctx, info.Snapshotter, activeSnapshotName, versionSnapshotName); err != nil {
return nil, err
}
if err := m.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{CleanupSnapshot: false}); err != nil {
return nil, err
}
spec, err := m.buildVersionSpec(botID)
if err != nil {
return nil, err
}
_, err = m.service.CreateContainerFromSnapshot(ctx, ctr.CreateContainerRequest{
ID: containerID,
ImageRef: info.Image,
SnapshotID: activeSnapshotName,
Snapshotter: info.Snapshotter,
Labels: info.Labels,
Spec: spec,
})
if err != nil {
if err := m.replaceContainerSnapshot(dctx, botID, containerID, info, activeSnapshotName, versionSnapshotName); err != nil {
return nil, err
}
versionID, versionNumber, createdAt, err := m.recordSnapshotVersion(
ctx,
dctx,
containerID,
versionSnapshotName,
info.SnapshotKey,
@@ -166,7 +175,7 @@ func (m *Manager) CreateVersion(ctx context.Context, botID string) (*VersionInfo
return nil, err
}
if err := m.insertEvent(ctx, containerID, "version_create", map[string]any{
if err := m.insertEvent(dctx, containerID, "version_create", map[string]any{
"snapshot_name": versionSnapshotName,
"version": versionNumber,
"version_id": versionID,
@@ -182,6 +191,67 @@ func (m *Manager) CreateVersion(ctx context.Context, botID string) (*VersionInfo
}, nil
}
// ListBotSnapshotData returns the raw snapshot data for a bot under the
// per-container lock, so callers never observe transient state during
// snapshot/version operations.
//
// The snapshotter is resolved in order from the container info, the manager
// configuration, and finally the "overlayfs" default. Both candidate values
// are whitespace-trimmed, so a blank configuration entry falls through to the
// default instead of being passed to the runtime verbatim.
//
// Managed metadata is loaded only when a query layer is configured; otherwise
// ManagedMeta is an empty, non-nil map. Rows whose runtime snapshot name is
// blank are skipped. Errors from the container service and the database are
// returned unwrapped so callers can still classify them (e.g. not-found).
func (m *Manager) ListBotSnapshotData(ctx context.Context, botID string) (*BotSnapshotData, error) {
	if err := validateBotID(botID); err != nil {
		return nil, err
	}
	containerID := m.containerID(botID)
	unlock := m.lockContainer(containerID)
	defer unlock()

	info, err := m.service.GetContainer(ctx, containerID)
	if err != nil {
		return nil, err
	}

	snapshotter := strings.TrimSpace(info.Snapshotter)
	if snapshotter == "" {
		// Trim the configured value as well: a whitespace-only setting must
		// not bypass the default below.
		snapshotter = strings.TrimSpace(m.cfg.Snapshotter)
	}
	if snapshotter == "" {
		snapshotter = "overlayfs"
	}

	runtimeSnapshots, err := m.service.ListSnapshots(ctx, snapshotter)
	if err != nil {
		return nil, err
	}

	managedMeta := make(map[string]ManagedSnapshotMeta)
	if m.queries != nil {
		rows, err := m.queries.ListSnapshotsWithVersionByContainerID(ctx, containerID)
		if err != nil {
			return nil, err
		}
		for _, row := range rows {
			name := strings.TrimSpace(row.RuntimeSnapshotName)
			if name == "" {
				continue
			}
			meta := ManagedSnapshotMeta{
				Source: strings.TrimSpace(row.Source),
			}
			if row.Version.Valid {
				// Copy into a fresh variable so every entry owns its pointer.
				v := int(row.Version.Int32)
				meta.Version = &v
			}
			managedMeta[name] = meta
		}
	}

	return &BotSnapshotData{
		ContainerID:      containerID,
		Info:             info,
		Snapshotter:      snapshotter,
		RuntimeSnapshots: runtimeSnapshots,
		ManagedMeta:      managedMeta,
	}, nil
}
func (m *Manager) ListVersions(ctx context.Context, botID string) ([]VersionInfo, error) {
if m.db == nil || m.queries == nil {
return nil, fmt.Errorf("db is not configured")
@@ -237,37 +307,18 @@ func (m *Manager) RollbackVersion(ctx context.Context, botID string, version int
return err
}
if err := m.safeStopTask(ctx, containerID); err != nil {
dctx := context.WithoutCancel(ctx)
if err := m.safeStopTask(dctx, containerID); err != nil {
return err
}
activeSnapshotName := fmt.Sprintf("%s-rollback-%d", containerID, time.Now().UnixNano())
if err := m.service.PrepareSnapshot(ctx, info.Snapshotter, activeSnapshotName, snapshotName); err != nil {
if err := m.replaceContainerSnapshot(dctx, botID, containerID, info, activeSnapshotName, snapshotName); err != nil {
return err
}
if err := m.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{CleanupSnapshot: false}); err != nil {
return err
}
spec, err := m.buildVersionSpec(botID)
if err != nil {
return err
}
_, err = m.service.CreateContainerFromSnapshot(ctx, ctr.CreateContainerRequest{
ID: containerID,
ImageRef: info.Image,
SnapshotID: activeSnapshotName,
Snapshotter: info.Snapshotter,
Labels: info.Labels,
Spec: spec,
})
if err != nil {
return err
}
return m.insertEvent(ctx, containerID, "version_rollback", map[string]any{
return m.insertEvent(dctx, containerID, "version_rollback", map[string]any{
"snapshot_name": snapshotName,
"version": version,
"source": SnapshotSourceRollback,
@@ -289,6 +340,44 @@ func (m *Manager) VersionSnapshotName(ctx context.Context, botID string, version
})
}
// replaceContainerSnapshot prepares a new active snapshot from parentSnapshot,
// deletes the old container, recreates it on the new snapshot, and restarts the task.
// Caller must pass a detached context (context.WithoutCancel) to guarantee atomicity.
//
// The container spec is built before anything in the runtime is touched, so a
// spec-building failure returns with the existing container still in place and
// no new snapshot prepared. After the container has been deleted, any error
// from recreate or start is returned to the caller as-is; the container is
// not restored by this function.
//
// A network setup failure is logged as a warning and does not fail the call.
func (m *Manager) replaceContainerSnapshot(ctx context.Context, botID, containerID string, info ctr.ContainerInfo, activeSnapshotName, parentSnapshot string) error {
	// Do the fallible, non-destructive work first.
	// NOTE(review): assumes buildVersionSpec does not depend on the prepared
	// snapshot or on the old container being gone — confirm.
	spec, err := m.buildVersionSpec(botID)
	if err != nil {
		return err
	}

	if err := m.service.PrepareSnapshot(ctx, info.Snapshotter, activeSnapshotName, parentSnapshot); err != nil {
		return err
	}

	// Keep the old snapshot: CleanupSnapshot is false, so only the container
	// record is removed here.
	if err := m.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{CleanupSnapshot: false}); err != nil {
		return err
	}

	if _, err := m.service.CreateContainerFromSnapshot(ctx, ctr.CreateContainerRequest{
		ID:          containerID,
		ImageRef:    info.Image,
		SnapshotID:  activeSnapshotName,
		Snapshotter: info.Snapshotter,
		Labels:      info.Labels,
		Spec:        spec,
	}); err != nil {
		return err
	}

	if err := m.service.StartContainer(ctx, containerID, &ctr.StartTaskOptions{UseStdio: false}); err != nil {
		return err
	}

	// Best effort: the task is already running, so a networking failure is
	// reported in the log rather than returned.
	if err := m.service.SetupNetwork(ctx, ctr.NetworkSetupRequest{
		ContainerID: containerID,
		CNIBinDir:   m.cfg.CNIBinaryDir,
		CNIConfDir:  m.cfg.CNIConfigDir,
	}); err != nil {
		m.logger.Warn("network setup failed after snapshot replace",
			slog.String("container_id", containerID), slog.Any("error", err))
	}
	return nil
}
func (m *Manager) buildVersionSpec(botID string) (ctr.ContainerSpec, error) {
dataDir, err := m.ensureBotDir(botID)
if err != nil {