feat: Add GPU CDI support for workspace containers (#332)

* feat: add CDI GPU support for workspace containers

* feat: expose GPU CDI settings in bot container UI

* feat: move GPU settings into advanced container options

* docs: document advanced CDI device configuration
This commit is contained in:
Ming Lin
2026-04-10 14:52:17 +08:00
committed by GitHub
parent 19619d73a9
commit 4d3f2de7e2
22 changed files with 752 additions and 84 deletions
+44 -10
View File
@@ -41,19 +41,25 @@ type ContainerdHandler struct {
policyService *policy.Service
}
// ContainerGPURequest is the optional "gpu" object of a container-create
// request; CreateContainerRequest references it through its GPU field.
type ContainerGPURequest struct {
// Devices is serialized as "devices" and left out of the JSON when empty.
// CreateContainer copies it into workspace.WorkspaceGPUConfig.Devices.
// NOTE(review): presumably a list of CDI device names — confirm the
// expected format against the workspace package.
Devices []string `json:"devices,omitempty"`
}
// CreateContainerRequest describes the JSON body of a container-create call.
// Every field is tagged omitempty, so zero values are left out when encoding.
//
// NOTE(review): Snapshotter, RestoreData and Image are each declared twice
// below. This looks like the removed and added lines of a diff rendered
// without +/- markers; Go rejects duplicate field names, so only one set can
// exist in the real file — confirm against the repository.
type CreateContainerRequest struct {
Snapshotter string `json:"snapshotter,omitempty"`
RestoreData bool `json:"restore_data,omitempty"`
Image string `json:"image,omitempty"`
// Snapshotter is whitespace-trimmed by CreateContainer before use.
Snapshotter string `json:"snapshotter,omitempty"`
RestoreData bool `json:"restore_data,omitempty"`
Image string `json:"image,omitempty"`
// GPU is optional. When it is non-nil, CreateContainer uses GPU.Devices in
// place of the configuration returned by ResolveWorkspaceGPU and, after a
// successful start, passes it to RememberWorkspaceGPU (a failure there is
// only logged as a warning).
GPU *ContainerGPURequest `json:"gpu,omitempty"`
}
// CreateContainerResponse is the JSON shape describing a created container.
//
// NOTE(review): every field except CDIDevices is declared twice below. This
// looks like the removed and added lines of a diff rendered without +/-
// markers; Go rejects duplicate field names, so only one set can exist in the
// real file — confirm against the repository.
type CreateContainerResponse struct {
ContainerID string `json:"container_id"`
Image string `json:"image"`
Snapshotter string `json:"snapshotter"`
Started bool `json:"started"`
DataRestored bool `json:"data_restored"`
HasPreservedData bool `json:"has_preserved_data"`
ContainerID string `json:"container_id"`
Image string `json:"image"`
Snapshotter string `json:"snapshotter"`
// CDIDevices is the only omitempty field: it is left out of the JSON when
// the slice is empty, while all other fields are always emitted.
CDIDevices []string `json:"cdi_devices,omitempty"`
Started bool `json:"started"`
DataRestored bool `json:"data_restored"`
HasPreservedData bool `json:"has_preserved_data"`
}
// codesync(container-create-stream): keep these SSE payloads in sync with
@@ -92,6 +98,7 @@ type GetContainerResponse struct {
Status string `json:"status"`
Namespace string `json:"namespace"`
ContainerPath string `json:"container_path"`
CDIDevices []string `json:"cdi_devices,omitempty"`
TaskRunning bool `json:"task_running"`
HasPreservedData bool `json:"has_preserved_data"`
Legacy bool `json:"legacy"`
@@ -217,9 +224,18 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
slog.String("bot_id", botID), slog.Any("error", err))
return nil
}
gpu, err := h.manager.ResolveWorkspaceGPU(ctx, botID)
if err != nil {
h.logger.Error("resolve workspace gpu failed",
slog.String("bot_id", botID), slog.Any("error", err))
return nil
}
if imageOverride != "" {
image = config.NormalizeImageRef(imageOverride)
}
if req.GPU != nil {
gpu = workspace.WorkspaceGPUConfig{Devices: req.GPU.Devices}
}
snapshotter := strings.TrimSpace(req.Snapshotter)
if snapshotter == "" {
@@ -283,7 +299,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
send(createContainerRestoringEvent{Type: "restoring"})
}
if err := h.manager.StartWithResolvedImage(ctx, botID, image); err != nil {
if err := h.manager.StartWithResolvedConfig(ctx, botID, image, gpu); err != nil {
h.logger.Error("container start failed",
slog.String("bot_id", botID), slog.Any("error", err))
sendError("container start failed: " + err.Error())
@@ -293,6 +309,12 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
h.logger.Warn("remember workspace image failed",
slog.String("bot_id", botID), slog.String("image", image), slog.Any("error", err))
}
if req.GPU != nil {
if err := h.manager.RememberWorkspaceGPU(ctx, botID, gpu); err != nil {
h.logger.Warn("remember workspace gpu failed",
slog.String("bot_id", botID), slog.Any("error", err))
}
}
containerID, err := h.manager.ContainerID(ctx, botID)
if err != nil {
@@ -315,6 +337,16 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
h.manager.RecordContainerRunning(ctx, botID, containerID, image)
status, statusErr := h.manager.GetContainerInfo(ctx, botID)
if statusErr != nil {
h.logger.Warn("load container status after start failed",
slog.String("bot_id", botID), slog.Any("error", statusErr))
}
cdiDevices := gpu.Devices
if status != nil {
cdiDevices = status.CDIDevices
}
// Phase 3: Complete
send(createContainerCompleteEvent{
Type: "complete",
@@ -322,6 +354,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
ContainerID: containerID,
Image: image,
Snapshotter: snapshotter,
CDIDevices: cdiDevices,
Started: true,
DataRestored: dataRestored,
HasPreservedData: h.manager.HasPreservedData(botID),
@@ -357,6 +390,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error {
Status: status.Status,
Namespace: status.Namespace,
ContainerPath: status.ContainerPath,
CDIDevices: status.CDIDevices,
TaskRunning: status.TaskRunning,
HasPreservedData: status.HasPreservedData,
Legacy: status.Legacy,