mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat: Add GPU CDI support for workspace containers (#332)
* feat: add CDI GPU support for workspace containers * feat: expose GPU CDI settings in bot container UI * feat: move GPU settings into advanced container options * docs: document advanced CDI device configuration
This commit is contained in:
@@ -41,19 +41,25 @@ type ContainerdHandler struct {
|
||||
policyService *policy.Service
|
||||
}
|
||||
|
||||
type ContainerGPURequest struct {
|
||||
Devices []string `json:"devices,omitempty"`
|
||||
}
|
||||
|
||||
type CreateContainerRequest struct {
|
||||
Snapshotter string `json:"snapshotter,omitempty"`
|
||||
RestoreData bool `json:"restore_data,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Snapshotter string `json:"snapshotter,omitempty"`
|
||||
RestoreData bool `json:"restore_data,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
GPU *ContainerGPURequest `json:"gpu,omitempty"`
|
||||
}
|
||||
|
||||
type CreateContainerResponse struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
Image string `json:"image"`
|
||||
Snapshotter string `json:"snapshotter"`
|
||||
Started bool `json:"started"`
|
||||
DataRestored bool `json:"data_restored"`
|
||||
HasPreservedData bool `json:"has_preserved_data"`
|
||||
ContainerID string `json:"container_id"`
|
||||
Image string `json:"image"`
|
||||
Snapshotter string `json:"snapshotter"`
|
||||
CDIDevices []string `json:"cdi_devices,omitempty"`
|
||||
Started bool `json:"started"`
|
||||
DataRestored bool `json:"data_restored"`
|
||||
HasPreservedData bool `json:"has_preserved_data"`
|
||||
}
|
||||
|
||||
// codesync(container-create-stream): keep these SSE payloads in sync with
|
||||
@@ -92,6 +98,7 @@ type GetContainerResponse struct {
|
||||
Status string `json:"status"`
|
||||
Namespace string `json:"namespace"`
|
||||
ContainerPath string `json:"container_path"`
|
||||
CDIDevices []string `json:"cdi_devices,omitempty"`
|
||||
TaskRunning bool `json:"task_running"`
|
||||
HasPreservedData bool `json:"has_preserved_data"`
|
||||
Legacy bool `json:"legacy"`
|
||||
@@ -217,9 +224,18 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
slog.String("bot_id", botID), slog.Any("error", err))
|
||||
return nil
|
||||
}
|
||||
gpu, err := h.manager.ResolveWorkspaceGPU(ctx, botID)
|
||||
if err != nil {
|
||||
h.logger.Error("resolve workspace gpu failed",
|
||||
slog.String("bot_id", botID), slog.Any("error", err))
|
||||
return nil
|
||||
}
|
||||
if imageOverride != "" {
|
||||
image = config.NormalizeImageRef(imageOverride)
|
||||
}
|
||||
if req.GPU != nil {
|
||||
gpu = workspace.WorkspaceGPUConfig{Devices: req.GPU.Devices}
|
||||
}
|
||||
|
||||
snapshotter := strings.TrimSpace(req.Snapshotter)
|
||||
if snapshotter == "" {
|
||||
@@ -283,7 +299,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
send(createContainerRestoringEvent{Type: "restoring"})
|
||||
}
|
||||
|
||||
if err := h.manager.StartWithResolvedImage(ctx, botID, image); err != nil {
|
||||
if err := h.manager.StartWithResolvedConfig(ctx, botID, image, gpu); err != nil {
|
||||
h.logger.Error("container start failed",
|
||||
slog.String("bot_id", botID), slog.Any("error", err))
|
||||
sendError("container start failed: " + err.Error())
|
||||
@@ -293,6 +309,12 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
h.logger.Warn("remember workspace image failed",
|
||||
slog.String("bot_id", botID), slog.String("image", image), slog.Any("error", err))
|
||||
}
|
||||
if req.GPU != nil {
|
||||
if err := h.manager.RememberWorkspaceGPU(ctx, botID, gpu); err != nil {
|
||||
h.logger.Warn("remember workspace gpu failed",
|
||||
slog.String("bot_id", botID), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
containerID, err := h.manager.ContainerID(ctx, botID)
|
||||
if err != nil {
|
||||
@@ -315,6 +337,16 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
|
||||
h.manager.RecordContainerRunning(ctx, botID, containerID, image)
|
||||
|
||||
status, statusErr := h.manager.GetContainerInfo(ctx, botID)
|
||||
if statusErr != nil {
|
||||
h.logger.Warn("load container status after start failed",
|
||||
slog.String("bot_id", botID), slog.Any("error", statusErr))
|
||||
}
|
||||
cdiDevices := gpu.Devices
|
||||
if status != nil {
|
||||
cdiDevices = status.CDIDevices
|
||||
}
|
||||
|
||||
// Phase 3: Complete
|
||||
send(createContainerCompleteEvent{
|
||||
Type: "complete",
|
||||
@@ -322,6 +354,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
ContainerID: containerID,
|
||||
Image: image,
|
||||
Snapshotter: snapshotter,
|
||||
CDIDevices: cdiDevices,
|
||||
Started: true,
|
||||
DataRestored: dataRestored,
|
||||
HasPreservedData: h.manager.HasPreservedData(botID),
|
||||
@@ -357,6 +390,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error {
|
||||
Status: status.Status,
|
||||
Namespace: status.Namespace,
|
||||
ContainerPath: status.ContainerPath,
|
||||
CDIDevices: status.CDIDevices,
|
||||
TaskRunning: status.TaskRunning,
|
||||
HasPreservedData: status.HasPreservedData,
|
||||
Legacy: status.Legacy,
|
||||
|
||||
Reference in New Issue
Block a user