feat: Add GPU CDI support for workspace containers (#332)

* feat: add CDI GPU support for workspace containers

* feat: expose GPU CDI settings in bot container UI

* feat: move GPU settings into advanced container options

* docs: document advanced CDI device configuration
This commit is contained in:
Ming Lin
2026-04-10 14:52:17 +08:00
committed by GitHub
parent 19619d73a9
commit 4d3f2de7e2
22 changed files with 752 additions and 84 deletions
+11
View File
@@ -748,6 +748,16 @@
"createRestoreDataDescription": "If a previously exported backup or legacy bind-mounted data exists, it will be restored into `/data` after the container is created.",
"createImageLabel": "Base image",
"createImageDescription": "Docker image to use as the container base (e.g. debian:bookworm-slim, alpine:latest, ubuntu:24.04). Leave empty for the default.",
"createAdvancedTitle": "Advanced options",
"createAdvancedDescription": "Configure optional GPU access and raw CDI device mappings for this container.",
"createGpuLabel": "Enable GPU",
"createGpuDescription": "Attach host GPU access to the new container.",
"createGpuDevicesLabel": "CDI devices",
"createGpuDevicesDescription": "Enter one CDI device per line or separate them with commas, for example `nvidia.com/gpu=0` or `amd.com/gpu=0`. Turning GPU off explicitly clears the saved GPU preference.",
"createGpuDevicesPlaceholder": "nvidia.com/gpu=0\namd.com/gpu=0",
"gpuDevicesRequired": "At least one CDI device is required when GPU is enabled.",
"cdiDevicesEmpty": "No GPU attached",
"gpuRecreateHint": "Changing GPU settings requires recreating the container. A simple start or stop will not change the devices already attached.",
"deleteConfirm": "Are you sure you want to permanently delete this container? Unpreserved data cannot be recovered.",
"deletePreserveConfirm": "Are you sure you want to export `/data` and then delete this container?",
"restoreConfirm": "Are you sure you want to restore preserved data into this container's `/data`?",
@@ -797,6 +807,7 @@
"task": "Task",
"namespace": "Namespace",
"image": "Image",
"cdiDevices": "CDI Devices",
"hostPath": "Host Path",
"containerPath": "Container Path",
"preservedData": "Preserved Data",
+11
View File
@@ -744,6 +744,16 @@
"createRestoreDataDescription": "如果存在之前导出的备份或旧版 bind mount 数据,将在容器创建后恢复到 `/data`。",
"createImageLabel": "基础镜像",
"createImageDescription": "作为容器基础环境的 Docker 镜像(如 debian:bookworm-slim、alpine:latest、ubuntu:24.04)。留空则使用默认镜像。",
"createAdvancedTitle": "高级选项",
"createAdvancedDescription": "配置该容器的可选 GPU 访问能力与原始 CDI 设备映射。",
"createGpuLabel": "启用 GPU",
"createGpuDescription": "为新容器开启宿主机 GPU 访问。",
"createGpuDevicesLabel": "CDI 设备",
"createGpuDevicesDescription": "每行填写一个 CDI 设备名,或用逗号分隔多个设备名,例如 `nvidia.com/gpu=0` 或 `amd.com/gpu=0`。关闭 GPU 后会显式清空已保存的 GPU 偏好。",
"createGpuDevicesPlaceholder": "nvidia.com/gpu=0\namd.com/gpu=0",
"gpuDevicesRequired": "已启用 GPU 时,至少需要填写一个 CDI 设备名。",
"cdiDevicesEmpty": "未附加 GPU",
"gpuRecreateHint": "GPU 配置变更需要重建容器后才会生效,单纯启动或停止不会更新当前已附加的设备。",
"deleteConfirm": "确定要彻底删除这个容器吗?未保留的数据将无法恢复。",
"deletePreserveConfirm": "确定要先导出 `/data` 再删除这个容器吗?",
"restoreConfirm": "确定要将已保留的数据恢复到当前容器的 `/data` 吗?",
@@ -793,6 +803,7 @@
"task": "任务状态",
"namespace": "命名空间",
"image": "镜像",
"cdiDevices": "CDI 设备",
"hostPath": "主机路径",
"containerPath": "容器路径",
"preservedData": "保留数据",
@@ -4,6 +4,7 @@ import { toast } from 'vue-sonner'
import { useI18n } from 'vue-i18n'
import { useRoute } from 'vue-router'
import { useQuery } from '@pinia/colada'
import { ChevronRight } from 'lucide-vue-next'
import {
deleteBotsByBotIdContainer,
getBotsByBotIdContainer,
@@ -25,7 +26,7 @@ import {
type ContainerCreateLayerStatus,
type ContainerCreateStreamEvent,
} from '@/composables/api/useContainerStream'
import { Button, Input, Label, Separator, Spinner, Switch } from '@memohai/ui'
import { Button, Collapsible, CollapsibleContent, CollapsibleTrigger, Input, Label, Separator, Spinner, Switch, Textarea } from '@memohai/ui'
import ConfirmPopover from '@/components/confirm-popover/index.vue'
import ContainerCreateProgress from './container-create-progress.vue'
import { useSyncedQueryParam } from '@/composables/useSyncedQueryParam'
@@ -59,6 +60,10 @@ const rollbackVersion = ref<number | null>(null)
const createRestoreData = ref(false)
const createImage = ref('')
const createImagePrefilled = ref(false)
const createGPUEnabled = ref(false)
const createGPUDevices = ref('')
const createGPUPrefilled = ref(false)
const createAdvancedOpen = ref(false)
const newSnapshotName = ref('')
const importInputRef = ref<HTMLInputElement | null>(null)
@@ -178,7 +183,7 @@ async function handleRefreshContainer() {
await runContainerAction('refresh', () => loadContainerData(false))
}
const { data: bot } = useQuery({
const { data: bot, refetch: refetchBot } = useQuery({
key: () => ['bot', botId.value],
query: async () => {
const { data } = await getBotsById({ path: { id: botId.value }, throwOnError: true })
@@ -194,8 +199,48 @@ function rememberedWorkspaceImage(metadata: Record<string, unknown> | undefined)
return typeof image === 'string' ? shortenImageRef(image) : ''
}
/** GPU preference read back from a bot's `metadata.workspace.gpu` entry. */
type RememberedWorkspaceGPU = {
  /** True when the workspace record carries a `gpu` key, whatever its value. */
  exists: boolean
  /** Trimmed, de-duplicated device names in first-seen order. */
  devices: string[]
}

/**
 * Extracts the saved GPU preference from bot metadata.
 *
 * @param metadata - Bot metadata object, or `undefined` when not loaded.
 * @returns `exists: false` when `workspace` is not a plain object or has no
 *   own `gpu` key; otherwise `exists: true` with every non-blank string found
 *   in `gpu.devices` (an empty list when `gpu` or `gpu.devices` is malformed).
 */
function rememberedWorkspaceGPU(metadata: Record<string, unknown> | undefined): RememberedWorkspaceGPU {
  // Accept only non-null, non-array objects as records.
  const asRecord = (candidate: unknown): Record<string, unknown> | null =>
    candidate && typeof candidate === 'object' && !Array.isArray(candidate)
      ? candidate as Record<string, unknown>
      : null

  const workspaceRecord = asRecord(metadata?.workspace)
  if (workspaceRecord === null || !Object.prototype.hasOwnProperty.call(workspaceRecord, 'gpu')) {
    return { exists: false, devices: [] }
  }

  const gpuRecord = asRecord(workspaceRecord.gpu)
  if (gpuRecord === null) {
    // The key is present but unusable: report it as an explicit "no devices".
    return { exists: true, devices: [] }
  }

  const unique = new Set<string>()
  const listed = gpuRecord.devices
  if (Array.isArray(listed)) {
    for (const entry of listed) {
      if (typeof entry !== 'string') continue
      const name = entry.trim()
      if (name) unique.add(name)
    }
  }
  return { exists: true, devices: [...unique] }
}
/**
 * Splits free-form textarea input into CDI device names.
 *
 * Entries may be separated by newlines or commas. Surrounding whitespace is
 * trimmed, blank entries are dropped, and duplicates are removed while
 * keeping first-seen order.
 *
 * @param value - Raw text entered by the user.
 * @returns The distinct, non-empty device names.
 */
function parseCDIDevices(value: string): string[] {
  const unique = new Set<string>()
  for (const piece of value.split(/[\n,]/)) {
    const name = piece.trim()
    if (name) unique.add(name)
  }
  return Array.from(unique)
}
// Workspace image remembered in the loaded bot's metadata ('' when none is stored).
const rememberedCreateImage = computed(() => rememberedWorkspaceImage(bot.value?.metadata as Record<string, unknown> | undefined))
// GPU preference remembered in the loaded bot's metadata; `exists` is false when no `gpu` entry is stored.
const rememberedCreateGPU = computed(() => rememberedWorkspaceGPU(bot.value?.metadata as Record<string, unknown> | undefined))
// Shortened image reference of the current container, for display.
const displayedContainerImage = computed(() => shortenImageRef(containerInfo.value?.image))
// CDI devices reported for the current container; falls back to an empty list when the field is absent.
const displayedCDIDevices = computed(() => containerInfo.value?.cdi_devices ?? [])
const { isPending: botLifecyclePending } = useBotStatusMeta(bot, t)
@@ -248,16 +293,29 @@ async function handleCreateContainer() {
containerAction.value = 'create'
createProgress.value = { phase: 'pulling' }
try {
const gpuDevices = parseCDIDevices(createGPUDevices.value)
if (createGPUEnabled.value && gpuDevices.length === 0) {
throw new Error(t('bots.container.gpuDevicesRequired'))
}
const body: HandlersCreateContainerRequest = {
restore_data: createRestoreData.value,
}
const trimmedImage = createImage.value.trim()
if (trimmedImage) body.image = trimmedImage
if (createGPUEnabled.value || rememberedCreateGPU.value.exists) {
body.gpu = {
devices: createGPUEnabled.value ? gpuDevices : [],
}
}
const { dataRestored } = await createContainerSSE(body)
createRestoreData.value = false
createImage.value = ''
createGPUEnabled.value = false
createGPUDevices.value = ''
await loadContainerData(false)
await refetchBot()
toast.success(dataRestored
? t('bots.container.createRestoreSuccess')
: t('bots.container.createSuccess'))
@@ -567,6 +625,8 @@ const activeTab = useSyncedQueryParam('tab', 'overview')
watch(containerMissing, (missing) => {
if (!missing) {
createImagePrefilled.value = false
createGPUPrefilled.value = false
createAdvancedOpen.value = false
}
})
@@ -577,6 +637,15 @@ watch([containerMissing, rememberedCreateImage], ([missing, remembered]) => {
createImagePrefilled.value = true
}, { immediate: true })
// Prefill the create form's GPU fields from the remembered preference.
// Runs immediately and whenever the container-missing flag or the remembered
// preference changes; `createGPUPrefilled` keeps it from firing more than once
// until that flag is reset elsewhere.
watch([containerMissing, rememberedCreateGPU], ([missing, remembered]) => {
// Only relevant while no container exists and no prefill has happened yet.
if (!missing || createGPUPrefilled.value) return
// No saved `gpu` entry: leave the form defaults untouched.
if (!remembered.exists) return
// Never overwrite GPU input that is already present in the form.
if (createGPUEnabled.value || createGPUDevices.value.trim()) return
// A saved preference with no devices prefills the switch as "off".
createGPUEnabled.value = remembered.devices.length > 0
// One device per line, matching the textarea's input format.
createGPUDevices.value = remembered.devices.join('\n')
createGPUPrefilled.value = true
}, { immediate: true })
watch([activeTab, botId], ([tab]) => {
if (!botId.value) return
if (tab === 'container') {
@@ -685,6 +754,59 @@ watch([activeTab, botId], ([tab]) => {
</p>
</div>
<Collapsible v-model:open="createAdvancedOpen">
<div class="rounded-md border">
<CollapsibleTrigger class="flex w-full items-center justify-between gap-3 px-3 py-2 text-left hover:bg-accent/40">
<div class="space-y-1">
<p class="text-xs font-medium">
{{ $t('bots.container.createAdvancedTitle') }}
</p>
<p class="text-xs text-muted-foreground">
{{ $t('bots.container.createAdvancedDescription') }}
</p>
</div>
<ChevronRight
class="size-4 shrink-0 text-muted-foreground transition-transform"
:class="{ 'rotate-90': createAdvancedOpen }"
/>
</CollapsibleTrigger>
<CollapsibleContent>
<div class="space-y-4 border-t px-3 py-3">
<div class="flex items-start justify-between gap-4 rounded-md border p-3">
<div class="space-y-1">
<Label>{{ $t('bots.container.createGpuLabel') }}</Label>
<p class="text-xs text-muted-foreground">
{{ $t('bots.container.createGpuDescription') }}
</p>
</div>
<Switch
:model-value="createGPUEnabled"
:disabled="containerBusy || botLifecyclePending"
@update:model-value="(value) => createGPUEnabled = !!value"
/>
</div>
<div
v-if="createGPUEnabled"
class="space-y-2"
>
<Label>{{ $t('bots.container.createGpuDevicesLabel') }}</Label>
<Textarea
v-model="createGPUDevices"
:placeholder="$t('bots.container.createGpuDevicesPlaceholder')"
:disabled="containerBusy || botLifecyclePending"
class="min-h-24 font-mono text-xs"
/>
<p class="text-xs text-muted-foreground">
{{ $t('bots.container.createGpuDevicesDescription') }}
</p>
</div>
</div>
</CollapsibleContent>
</div>
</Collapsible>
<div class="flex justify-end">
<Button
:disabled="containerBusy || botLifecyclePending"
@@ -784,6 +906,29 @@ watch([activeTab, botId], ([tab]) => {
{{ displayedContainerImage }}
</dd>
</div>
<div class="space-y-1 sm:col-span-2">
<dt class="text-muted-foreground">
{{ $t('bots.container.fields.cdiDevices') }}
</dt>
<dd
v-if="displayedCDIDevices.length === 0"
class="text-muted-foreground"
>
{{ $t('bots.container.cdiDevicesEmpty') }}
</dd>
<dd
v-else
class="space-y-1 font-mono text-xs"
>
<div
v-for="device in displayedCDIDevices"
:key="device"
class="break-all"
>
{{ device }}
</div>
</dd>
</div>
<div class="space-y-1 sm:col-span-2">
<dt class="text-muted-foreground">
{{ $t('bots.container.fields.containerPath') }}
@@ -813,6 +958,10 @@ watch([activeTab, botId], ([tab]) => {
</dl>
</div>
<div class="rounded-md border px-3 py-2 text-xs text-muted-foreground">
{{ $t('bots.container.gpuRecreateHint') }}
</div>
<div class="space-y-4 rounded-md border p-4">
<div class="space-y-1">
<h4 class="text-xs font-medium">
+3
View File
@@ -88,6 +88,9 @@ services:
- containerd_data:/var/lib/containerd
- server_cni_state:/var/lib/cni
- memoh_data:/opt/memoh/data
# Expose host CDI specs to the nested containerd used for bot workspaces.
- /etc/cdi:/etc/cdi:ro
- /var/run/cdi:/var/run/cdi:ro
# Toolkit: run ./docker/toolkit/install.sh once before first use
- ../.toolkit:/opt/memoh/runtime/toolkit
- ../docker/toolkit/bin:/opt/memoh/runtime/toolkit/bin
+3
View File
@@ -45,6 +45,9 @@ services:
- containerd_data:/var/lib/containerd
- server_cni_state:/var/lib/cni
- memoh_data:/opt/memoh/data
# Expose host CDI specs to the nested containerd used for bot workspaces.
- /etc/cdi:/etc/cdi:ro
- /var/run/cdi:/var/run/cdi:ro
- /etc/localtime:/etc/localtime:ro
ports:
- "8080:8080"
+46
View File
@@ -33,6 +33,52 @@ The **Container** tab displays real-time data about the bot's runtime:
- **Image**: The Docker/Containerd image used as the base.
- **Paths**: Host and container paths for data persistence.
- **Tasks**: Number of active background tasks running in the container.
- **CDI Devices**: The effective GPU CDI devices currently attached to the container, if any.
---
## Advanced: Provide CDI Devices
Memoh can provide host devices to a bot container through CDI (Container Device Interface). This is an advanced capability for users who want to expose host-managed devices, most commonly GPUs, to the container runtime.
In the Web UI, this capability is placed under **Advanced options** in the **Container** tab. It is optional and only needs to be configured when the bot must access CDI-backed devices from the host.
### Configure CDI Devices
1. Open the Bot's **Container** tab.
2. Click **Create** if the container does not exist, or recreate the container if you need to change GPU settings.
3. Expand **Advanced options**.
4. Enable **GPU**.
5. Enter one or more CDI device names in **CDI devices**.
You can enter CDI device names one per line or separated with commas. Common GPU-related examples:
- `nvidia.com/gpu=0`
- `nvidia.com/gpu=all`
- `amd.com/gpu=0`
- `amd.com/gpu=all`
### Host Requirements
Before configuring CDI devices in Memoh, the host machine must already provide working device drivers, vendor toolkit support where required, and valid CDI specs. In practice, this usually means:
- the host GPU works normally outside the container
- CDI spec files exist under `/etc/cdi` or `/var/run/cdi`
- the device name you enter in Memoh matches a real CDI device on the host
To discover the exact CDI device names exposed by the host, use the vendor tool on the host machine:
- NVIDIA: `nvidia-ctk cdi list`
- AMD: `amd-ctk cdi list`
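If Memoh reports an error such as `unresolvable CDI devices`, the configured device name does not match any CDI device visible to the container runtime. Check the name against the output of the vendor tool above, and, when Memoh runs through Docker Compose, confirm that the host's `/etc/cdi` and `/var/run/cdi` directories are mounted into the server container.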
### Important Behavior
- CDI device settings are applied when the container is created. Updating the setting later requires recreating the container.
- Stopping and starting an existing container does not change its attached CDI devices.
- The container image still needs the appropriate user-space libraries and tools if you want to run CUDA or ROCm software inside the container.
- After creation, the **Container** tab shows the effective attached CDI devices for verification.
---
+7
View File
@@ -71,6 +71,7 @@ require (
github.com/distribution/reference v0.6.0 // indirect
github.com/emersion/go-message v0.18.2 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -104,6 +105,7 @@ require (
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/locker v1.0.1 // indirect
github.com/moby/sys/capability v0.4.0 // indirect
github.com/moby/sys/mountinfo v0.7.2 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/signal v0.7.1 // indirect
@@ -113,6 +115,7 @@ require (
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
github.com/oapi-codegen/runtime v1.1.2 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
github.com/opencontainers/selinux v1.13.1 // indirect
github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pkg/errors v0.9.1 // indirect
@@ -134,6 +137,7 @@ require (
go.uber.org/dig v1.19.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.1 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/net v0.50.0 // indirect
@@ -143,4 +147,7 @@ require (
golang.org/x/time v0.14.0 // indirect
golang.org/x/tools v0.42.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect
sigs.k8s.io/yaml v1.6.0 // indirect
tags.cncf.io/container-device-interface v1.1.0 // indirect
tags.cncf.io/container-device-interface/specs-go v1.1.0 // indirect
)
+14
View File
@@ -90,6 +90,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
@@ -236,6 +238,8 @@ github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3N
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc=
github.com/moby/sys/capability v0.4.0 h1:4D4mI6KlNtWMCM1Z/K0i7RV1FkX+DBDHKVJpCndZoHk=
github.com/moby/sys/capability v0.4.0/go.mod h1:4g9IK291rVkms3LKCDOoYlnV8xKwoDTpIrNEE35Wq0I=
github.com/moby/sys/mountinfo v0.7.2 h1:1shs6aH5s4o5H2zQLn796ADW1wMrIwHsyJ2v9KouLrg=
github.com/moby/sys/mountinfo v0.7.2/go.mod h1:1YOa8w8Ih7uW0wALDUgT1dTTSBrZ+HiBLGws92L2RU4=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
@@ -271,6 +275,8 @@ github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJw
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/opencontainers/runtime-spec v1.3.0 h1:YZupQUdctfhpZy3TM39nN9Ika5CBWT5diQ8ibYCRkxg=
github.com/opencontainers/runtime-spec v1.3.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 h1:tAKu3NkKWZYpqBSOJKwTxT1wIGueiF7gcmcNgr5pNTY=
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116/go.mod h1:DKDEfzxvRkoQ6n9TGhxQgg2IM1lY4aM0eaQP4e3oElw=
github.com/opencontainers/selinux v1.13.1 h1:A8nNeceYngH9Ow++M+VVEwJVpdFmrlxsN22F+ISDCJE=
github.com/opencontainers/selinux v1.13.1/go.mod h1:S10WXZ/osk2kWOYKy1x2f/eXF5ZHJoUs8UU/2caNRbg=
github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
@@ -364,6 +370,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -519,3 +527,9 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
tags.cncf.io/container-device-interface v1.1.0 h1:RnxNhxF1JOu6CJUVpetTYvrXHdxw9j9jFYgZpI+anSY=
tags.cncf.io/container-device-interface v1.1.0/go.mod h1:76Oj0Yqp9FwTx/pySDc8Bxjpg+VqXfDb50cKAXVJ34Q=
tags.cncf.io/container-device-interface/specs-go v1.1.0 h1:QRZVeAceQM+zTZe12eyfuJuuzp524EKYwhmvLd+h+yQ=
tags.cncf.io/container-device-interface/specs-go v1.1.0/go.mod h1:u86hoFWqnh3hWz3esofRFKbI261bUlvUfLKGrDhJkgQ=
+19
View File
@@ -13,15 +13,18 @@ import (
tasksv1 "github.com/containerd/containerd/api/services/tasks/v1"
tasktypes "github.com/containerd/containerd/api/types/task"
containerd "github.com/containerd/containerd/v2/client"
"github.com/containerd/containerd/v2/core/containers"
"github.com/containerd/containerd/v2/core/images"
"github.com/containerd/containerd/v2/core/remotes/docker"
"github.com/containerd/containerd/v2/core/snapshots"
cdispec "github.com/containerd/containerd/v2/pkg/cdi"
"github.com/containerd/containerd/v2/pkg/cio"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/containerd/v2/pkg/oci"
"github.com/containerd/errdefs"
"github.com/opencontainers/image-spec/identity"
"github.com/opencontainers/runtime-spec/specs-go"
cdi "tags.cncf.io/container-device-interface/pkg/cdi"
"github.com/memohai/memoh/internal/config"
)
@@ -248,10 +251,26 @@ func specOptsFromSpec(spec ContainerSpec) []oci.SpecOpts {
}
opts = append(opts, oci.WithMounts(mounts))
}
if len(spec.CDIDevices) > 0 {
opts = append(opts, withStaticCDIRegistry())
opts = append(opts, cdispec.WithCDIDevices(spec.CDIDevices...))
}
return opts
}
// withStaticCDIRegistry returns a spec option that reconfigures the CDI
// package with auto-refresh disabled and then triggers a single explicit
// refresh. It never fails: both steps are best effort, and the returned
// option always reports nil.
func withStaticCDIRegistry() oci.SpecOpts {
	return func(context.Context, oci.Client, *containers.Container, *oci.Spec) error {
		_ = cdi.Configure(cdi.WithAutoRefresh(false))
		// Invalid specs for other vendors should not block injection of a
		// resolvable device set for the current container, so a refresh
		// error is deliberately discarded.
		_ = cdi.Refresh()
		return nil
	}
}
func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContainerRequest) (ContainerInfo, error) {
if req.ID == "" || req.ImageRef == "" {
return ContainerInfo{}, ErrInvalidArgument
+3
View File
@@ -161,6 +161,9 @@ func (s *AppleService) CreateContainer(ctx context.Context, req CreateContainerR
if req.ID == "" || req.ImageRef == "" {
return ContainerInfo{}, ErrInvalidArgument
}
if len(req.Spec.CDIDevices) > 0 {
return ContainerInfo{}, ErrNotSupported
}
if err := s.ensureHealthy(ctx); err != nil {
return ContainerInfo{}, err
}
+4 -1
View File
@@ -91,7 +91,10 @@ type ContainerSpec struct {
User string
Mounts []MountSpec
DNS []string
TTY bool
// CDIDevices contains fully-qualified CDI device names such as
// "nvidia.com/gpu=0" or "amd.com/gpu=0".
CDIDevices []string
TTY bool
}
type LayerStatus struct {
+44 -10
View File
@@ -41,19 +41,25 @@ type ContainerdHandler struct {
policyService *policy.Service
}
// ContainerGPURequest is the optional "gpu" object of a create-container
// request body. Devices carries the requested device names and is omitted
// from JSON when empty.
type ContainerGPURequest struct {
Devices []string `json:"devices,omitempty"`
}
type CreateContainerRequest struct {
Snapshotter string `json:"snapshotter,omitempty"`
RestoreData bool `json:"restore_data,omitempty"`
Image string `json:"image,omitempty"`
Snapshotter string `json:"snapshotter,omitempty"`
RestoreData bool `json:"restore_data,omitempty"`
Image string `json:"image,omitempty"`
GPU *ContainerGPURequest `json:"gpu,omitempty"`
}
type CreateContainerResponse struct {
ContainerID string `json:"container_id"`
Image string `json:"image"`
Snapshotter string `json:"snapshotter"`
Started bool `json:"started"`
DataRestored bool `json:"data_restored"`
HasPreservedData bool `json:"has_preserved_data"`
ContainerID string `json:"container_id"`
Image string `json:"image"`
Snapshotter string `json:"snapshotter"`
CDIDevices []string `json:"cdi_devices,omitempty"`
Started bool `json:"started"`
DataRestored bool `json:"data_restored"`
HasPreservedData bool `json:"has_preserved_data"`
}
// codesync(container-create-stream): keep these SSE payloads in sync with
@@ -92,6 +98,7 @@ type GetContainerResponse struct {
Status string `json:"status"`
Namespace string `json:"namespace"`
ContainerPath string `json:"container_path"`
CDIDevices []string `json:"cdi_devices,omitempty"`
TaskRunning bool `json:"task_running"`
HasPreservedData bool `json:"has_preserved_data"`
Legacy bool `json:"legacy"`
@@ -217,9 +224,18 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
slog.String("bot_id", botID), slog.Any("error", err))
return nil
}
gpu, err := h.manager.ResolveWorkspaceGPU(ctx, botID)
if err != nil {
h.logger.Error("resolve workspace gpu failed",
slog.String("bot_id", botID), slog.Any("error", err))
return nil
}
if imageOverride != "" {
image = config.NormalizeImageRef(imageOverride)
}
if req.GPU != nil {
gpu = workspace.WorkspaceGPUConfig{Devices: req.GPU.Devices}
}
snapshotter := strings.TrimSpace(req.Snapshotter)
if snapshotter == "" {
@@ -283,7 +299,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
send(createContainerRestoringEvent{Type: "restoring"})
}
if err := h.manager.StartWithResolvedImage(ctx, botID, image); err != nil {
if err := h.manager.StartWithResolvedConfig(ctx, botID, image, gpu); err != nil {
h.logger.Error("container start failed",
slog.String("bot_id", botID), slog.Any("error", err))
sendError("container start failed: " + err.Error())
@@ -293,6 +309,12 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
h.logger.Warn("remember workspace image failed",
slog.String("bot_id", botID), slog.String("image", image), slog.Any("error", err))
}
if req.GPU != nil {
if err := h.manager.RememberWorkspaceGPU(ctx, botID, gpu); err != nil {
h.logger.Warn("remember workspace gpu failed",
slog.String("bot_id", botID), slog.Any("error", err))
}
}
containerID, err := h.manager.ContainerID(ctx, botID)
if err != nil {
@@ -315,6 +337,16 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
h.manager.RecordContainerRunning(ctx, botID, containerID, image)
status, statusErr := h.manager.GetContainerInfo(ctx, botID)
if statusErr != nil {
h.logger.Warn("load container status after start failed",
slog.String("bot_id", botID), slog.Any("error", statusErr))
}
cdiDevices := gpu.Devices
if status != nil {
cdiDevices = status.CDIDevices
}
// Phase 3: Complete
send(createContainerCompleteEvent{
Type: "complete",
@@ -322,6 +354,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
ContainerID: containerID,
Image: image,
Snapshotter: snapshotter,
CDIDevices: cdiDevices,
Started: true,
DataRestored: dataRestored,
HasPreservedData: h.manager.HasPreservedData(botID),
@@ -357,6 +390,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error {
Status: status.Status,
Namespace: status.Namespace,
ContainerPath: status.ContainerPath,
CDIDevices: status.CDIDevices,
TaskRunning: status.TaskRunning,
HasPreservedData: status.HasPreservedData,
Legacy: status.Legacy,
+29
View File
@@ -0,0 +1,29 @@
package workspace
import (
"reflect"
"testing"
)
func TestWorkspaceCDIDevicesLabelRoundTrip(t *testing.T) {
t.Parallel()
devices := []string{" nvidia.com/gpu=0 ", "amd.com/gpu=1", "nvidia.com/gpu=0"}
value := workspaceCDIDevicesLabelValue(devices)
got := workspaceCDIDevicesFromLabels(map[string]string{
WorkspaceCDIDevicesLabelKey: value,
})
want := []string{"nvidia.com/gpu=0", "amd.com/gpu=1"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("expected devices %v, got %v", want, got)
}
}
// TestWorkspaceCDIDevicesFromLabelsIgnoresMissingValue checks that decoding a
// nil label map yields no devices.
func TestWorkspaceCDIDevicesFromLabelsIgnoresMissingValue(t *testing.T) {
	t.Parallel()

	devices := workspaceCDIDevicesFromLabels(nil)
	if len(devices) == 0 {
		return
	}
	t.Fatalf("expected empty devices for nil labels, got %v", devices)
}
+162
View File
@@ -16,8 +16,14 @@ import (
const (
workspaceMetadataKey = "workspace"
workspaceImageMetadataKey = "image"
workspaceGPUMetadataKey = "gpu"
workspaceGPUDevicesKey = "devices"
)
// WorkspaceGPUConfig is the GPU preference for a bot workspace. Devices holds
// the device names; the field is omitted from JSON when empty.
type WorkspaceGPUConfig struct {
Devices []string `json:"devices,omitempty"`
}
func decodeBotMetadata(payload []byte) (map[string]any, error) {
if len(payload) == 0 {
return map[string]any{}, nil
@@ -61,6 +67,54 @@ func workspaceImageFromMetadata(metadata map[string]any) string {
return strings.TrimSpace(image)
}
// normalizeWorkspaceGPUDevices trims every entry, drops blank ones and removes
// duplicates while keeping first-seen order. It returns nil for an empty
// input, and an empty (non-nil) slice when every entry was blank.
func normalizeWorkspaceGPUDevices(devices []string) []string {
	if len(devices) == 0 {
		return nil
	}

	result := make([]string, 0, len(devices))
	known := make(map[string]struct{}, len(devices))
	for i := range devices {
		name := strings.TrimSpace(devices[i])
		_, duplicate := known[name]
		if name != "" && !duplicate {
			known[name] = struct{}{}
			result = append(result, name)
		}
	}
	return result
}
// workspaceGPUFromMetadata reads the GPU preference from the workspace section
// of bot metadata. The boolean reports whether a "gpu" key is present at all;
// a present but malformed entry yields an empty config with true. Devices may
// be stored as []string or []any (non-string items are skipped) and are
// normalized before being returned.
func workspaceGPUFromMetadata(metadata map[string]any) (WorkspaceGPUConfig, bool) {
	raw, present := workspaceSection(metadata)[workspaceGPUMetadataKey]
	if !present {
		return WorkspaceGPUConfig{}, false
	}
	gpuSection, isMap := raw.(map[string]any)
	if !isMap {
		return WorkspaceGPUConfig{}, true
	}

	var collected []string
	stored := gpuSection[workspaceGPUDevicesKey]
	if list, ok := stored.([]string); ok {
		collected = append(collected, list...)
	} else if list, ok := stored.([]any); ok {
		for _, item := range list {
			if name, isString := item.(string); isString {
				collected = append(collected, name)
			}
		}
	}
	return WorkspaceGPUConfig{Devices: normalizeWorkspaceGPUDevices(collected)}, true
}
func withWorkspaceImagePreference(metadata map[string]any, image string) map[string]any {
next := cloneAnyMap(metadata)
section := workspaceSection(next)
@@ -81,6 +135,28 @@ func withoutWorkspaceImagePreference(metadata map[string]any) map[string]any {
return next
}
// withWorkspaceGPUPreference returns a copy of metadata whose workspace
// section has its "gpu" entry set to the normalized devices of gpu.
func withWorkspaceGPUPreference(metadata map[string]any, gpu WorkspaceGPUConfig) map[string]any {
	gpuEntry := map[string]any{
		workspaceGPUDevicesKey: normalizeWorkspaceGPUDevices(gpu.Devices),
	}

	updated := cloneAnyMap(metadata)
	workspaceEntry := workspaceSection(updated)
	workspaceEntry[workspaceGPUMetadataKey] = gpuEntry
	updated[workspaceMetadataKey] = workspaceEntry
	return updated
}
// withoutWorkspaceGPUPreference returns a copy of metadata with the "gpu"
// entry removed from the workspace section. When that leaves the section
// empty, the section itself is removed as well.
func withoutWorkspaceGPUPreference(metadata map[string]any) map[string]any {
	updated := cloneAnyMap(metadata)
	remaining := workspaceSection(updated)
	delete(remaining, workspaceGPUMetadataKey)

	if len(remaining) > 0 {
		updated[workspaceMetadataKey] = remaining
	} else {
		delete(updated, workspaceMetadataKey)
	}
	return updated
}
func (m *Manager) botWorkspaceImagePreference(ctx context.Context, botID string) (string, error) {
if m.queries == nil {
return "", nil
@@ -132,6 +208,7 @@ func (m *Manager) updateBotWorkspaceImagePreference(ctx context.Context, botID,
ID: botUUID,
DisplayName: row.DisplayName,
AvatarUrl: row.AvatarUrl,
Timezone: row.Timezone,
IsActive: row.IsActive,
Metadata: payload,
})
@@ -146,10 +223,82 @@ func (m *Manager) ClearWorkspaceImagePreference(ctx context.Context, botID strin
return m.updateBotWorkspaceImagePreference(ctx, botID, "", true)
}
// botWorkspaceGPUPreference loads the bot row and extracts the GPU preference
// from its metadata. The boolean reports whether a preference is stored. A
// nil query handle or a missing bot row is reported as "no preference" rather
// than as an error; an unparsable bot ID, any other lookup failure, or
// undecodable metadata is returned as an error.
func (m *Manager) botWorkspaceGPUPreference(ctx context.Context, botID string) (WorkspaceGPUConfig, bool, error) {
	var none WorkspaceGPUConfig
	if m.queries == nil {
		return none, false, nil
	}

	botUUID, err := db.ParseUUID(botID)
	if err != nil {
		return none, false, err
	}

	row, err := m.queries.GetBotByID(ctx, botUUID)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return none, false, nil
	case err != nil:
		return none, false, err
	}

	metadata, err := decodeBotMetadata(row.Metadata)
	if err != nil {
		return none, false, err
	}

	gpu, found := workspaceGPUFromMetadata(metadata)
	return gpu, found, nil
}
// updateBotWorkspaceGPUPreference rewrites the GPU preference stored in the
// bot's metadata. With clearPreference set, the "gpu" entry is removed and gpu
// is ignored; otherwise the entry is replaced with gpu. It is a no-op
// returning nil when m.queries is nil. Unlike the read path, a failed bot
// lookup (including a missing row) is returned as an error.
func (m *Manager) updateBotWorkspaceGPUPreference(ctx context.Context, botID string, gpu WorkspaceGPUConfig, clearPreference bool) error {
if m.queries == nil {
return nil
}
botUUID, err := db.ParseUUID(botID)
if err != nil {
return err
}
// Read the current row so the metadata can be modified and written back.
row, err := m.queries.GetBotByID(ctx, botUUID)
if err != nil {
return err
}
metadata, err := decodeBotMetadata(row.Metadata)
if err != nil {
return err
}
if clearPreference {
metadata = withoutWorkspaceGPUPreference(metadata)
} else {
metadata = withWorkspaceGPUPreference(metadata, gpu)
}
payload, err := json.Marshal(metadata)
if err != nil {
return err
}
// The profile fields are passed through from the row just read; only
// Metadata carries a new value.
_, err = m.queries.UpdateBotProfile(ctx, dbsqlc.UpdateBotProfileParams{
ID: botUUID,
DisplayName: row.DisplayName,
AvatarUrl: row.AvatarUrl,
Timezone: row.Timezone,
IsActive: row.IsActive,
Metadata: payload,
})
return err
}
// RememberWorkspaceGPU stores gpu as the bot's workspace GPU preference. The
// device list is passed through normalizeWorkspaceGPUDevices before saving.
func (m *Manager) RememberWorkspaceGPU(ctx context.Context, botID string, gpu WorkspaceGPUConfig) error {
	normalized := gpu
	normalized.Devices = normalizeWorkspaceGPUDevices(gpu.Devices)

	const clearPreference = false
	return m.updateBotWorkspaceGPUPreference(ctx, botID, normalized, clearPreference)
}
// ClearWorkspaceGPUPreference removes the stored workspace GPU preference from
// the bot's metadata.
func (m *Manager) ClearWorkspaceGPUPreference(ctx context.Context, botID string) error {
	// The config argument is ignored when clearing, so the zero value is passed.
	var ignored WorkspaceGPUConfig
	return m.updateBotWorkspaceGPUPreference(ctx, botID, ignored, true)
}
// ResolveWorkspaceImage is the exported entry point for resolveWorkspaceImage;
// it forwards its arguments and returns the result unchanged.
func (m *Manager) ResolveWorkspaceImage(ctx context.Context, botID string) (string, error) {
	image, err := m.resolveWorkspaceImage(ctx, botID)
	return image, err
}
// ResolveWorkspaceGPU is the exported entry point for resolveWorkspaceGPU; it
// forwards its arguments and returns the result unchanged.
func (m *Manager) ResolveWorkspaceGPU(ctx context.Context, botID string) (WorkspaceGPUConfig, error) {
	gpu, err := m.resolveWorkspaceGPU(ctx, botID)
	return gpu, err
}
func (m *Manager) resolveWorkspaceImage(ctx context.Context, botID string) (string, error) {
if m.queries != nil {
pgBotID, err := db.ParseUUID(botID)
@@ -174,3 +323,16 @@ func (m *Manager) resolveWorkspaceImage(ctx context.Context, botID string) (stri
return m.imageRef(), nil
}
// resolveWorkspaceGPU returns the GPU configuration to use for the bot's
// workspace container.
//
// If a preference is stored for the bot it is returned with its device list
// normalized; otherwise the zero WorkspaceGPUConfig is returned.
func (m *Manager) resolveWorkspaceGPU(ctx context.Context, botID string) (WorkspaceGPUConfig, error) {
	stored, found, err := m.botWorkspaceGPUPreference(ctx, botID)
	if err != nil {
		return WorkspaceGPUConfig{}, err
	}
	if !found {
		return WorkspaceGPUConfig{}, nil
	}

	stored.Devices = normalizeWorkspaceGPUDevices(stored.Devices)
	return stored, nil
}
@@ -51,3 +51,71 @@ func TestWithoutWorkspaceImagePreferenceRemovesOnlyImageKey(t *testing.T) {
t.Fatalf("expected unrelated workspace metadata to remain, got %#v", workspace)
}
}
// TestWorkspaceGPUMetadataRoundTrip writes a GPU preference into metadata and
// reads it back. It checks that the devices come back trimmed and
// de-duplicated in first-seen order, and that unrelated workspace keys
// survive the write.
func TestWorkspaceGPUMetadataRoundTrip(t *testing.T) {
	t.Parallel()

	existingWorkspace := map[string]any{"keep": "value"}
	input := map[string]any{workspaceMetadataKey: existingWorkspace}
	requested := WorkspaceGPUConfig{
		Devices: []string{" nvidia.com/gpu=0 ", "amd.com/gpu=1", "nvidia.com/gpu=0"},
	}

	updated := withWorkspaceGPUPreference(input, requested)

	gpu, present := workspaceGPUFromMetadata(updated)
	if !present {
		t.Fatal("expected gpu preference to be present")
	}

	got := gpu.Devices
	want := []string{"nvidia.com/gpu=0", "amd.com/gpu=1"}
	if len(got) != len(want) || got[0] != want[0] || got[1] != want[1] {
		t.Fatalf("expected normalized gpu devices %v, got %v", want, got)
	}

	workspace, isMap := updated[workspaceMetadataKey].(map[string]any)
	if !isMap {
		t.Fatal("expected workspace metadata section")
	}
	if workspace["keep"] != "value" {
		t.Fatalf("expected existing workspace metadata to be preserved, got %#v", workspace)
	}
}
func TestWorkspaceGPUExplicitDisableRemainsPresent(t *testing.T) {
t.Parallel()
metadata := withWorkspaceGPUPreference(map[string]any{}, WorkspaceGPUConfig{})
gpu, ok := workspaceGPUFromMetadata(metadata)
if !ok {
t.Fatal("expected gpu preference key to remain present")
}
if len(gpu.Devices) != 0 {
t.Fatalf("expected explicit disable with no devices, got %#v", gpu.Devices)
}
}
func TestWithoutWorkspaceGPUPreferenceRemovesOnlyGPUKey(t *testing.T) {
t.Parallel()
metadata := map[string]any{
workspaceMetadataKey: map[string]any{
workspaceGPUMetadataKey: map[string]any{
workspaceGPUDevicesKey: []any{"nvidia.com/gpu=all"},
},
"keep": true,
},
}
updated := withoutWorkspaceGPUPreference(metadata)
if _, ok := workspaceGPUFromMetadata(updated); ok {
t.Fatal("expected gpu preference to be cleared")
}
workspace, ok := updated[workspaceMetadataKey].(map[string]any)
if !ok {
t.Fatal("expected workspace metadata section to remain")
}
if workspace["keep"] != true {
t.Fatalf("expected unrelated workspace metadata to remain, got %#v", workspace)
}
}
+85 -29
View File
@@ -22,11 +22,12 @@ import (
)
const (
BotLabelKey = "memoh.bot_id"
WorkspaceLabelKey = "memoh.workspace"
WorkspaceLabelValue = "v3"
ContainerPrefix = "workspace-"
LegacyContainerPrefix = "mcp-"
BotLabelKey = "memoh.bot_id"
WorkspaceLabelKey = "memoh.workspace"
WorkspaceLabelValue = "v3"
WorkspaceCDIDevicesLabelKey = "memoh.workspace.cdi_devices"
ContainerPrefix = "workspace-"
LegacyContainerPrefix = "mcp-"
legacyGRPCPort = 9090
)
@@ -41,6 +42,7 @@ type ContainerStatus struct {
Status string `json:"status"`
Namespace string `json:"namespace"`
ContainerPath string `json:"container_path"`
CDIDevices []string `json:"cdi_devices,omitempty"`
TaskRunning bool `json:"task_running"`
HasPreservedData bool `json:"has_preserved_data"`
Legacy bool `json:"legacy"`
@@ -183,23 +185,39 @@ func (m *Manager) EnsureBot(ctx context.Context, botID, imageOverride string) er
if imageOverride != "" {
image = config.NormalizeImageRef(imageOverride)
}
return m.ensureBotWithImage(ctx, botID, image)
}
func (m *Manager) ensureBotWithImage(ctx context.Context, botID, image string) error {
if err := validateBotID(botID); err != nil {
return err
}
resolvPath, err := ctr.ResolveConfSource(m.dataRoot())
gpu, err := m.resolveWorkspaceGPU(ctx, botID)
if err != nil {
return err
}
return m.ensureBotWithImage(ctx, botID, image, gpu)
}
// workspaceCDIDevicesLabelValue encodes a device list as a single label value:
// the normalized devices joined with commas. The result is the empty string
// when normalization leaves no devices.
func workspaceCDIDevicesLabelValue(devices []string) string {
	return strings.Join(normalizeWorkspaceGPUDevices(devices), ",")
}
// workspaceCDIDevicesFromLabels decodes the device list stored under
// WorkspaceCDIDevicesLabelKey. It returns nil when the label is absent or
// blank; otherwise the comma-separated value is split and normalized.
func workspaceCDIDevicesFromLabels(labels map[string]string) []string {
	// Indexing a nil or empty map yields "", so a missing label takes the
	// same path as a blank one.
	if raw := strings.TrimSpace(labels[WorkspaceCDIDevicesLabelKey]); raw != "" {
		return normalizeWorkspaceGPUDevices(strings.Split(raw, ","))
	}
	return nil
}
func (m *Manager) buildWorkspaceContainerSpec(botID string, gpu WorkspaceGPUConfig) (ctr.ContainerSpec, error) {
resolvPath, err := ctr.ResolveConfSource(m.dataRoot())
if err != nil {
return ctr.ContainerSpec{}, err
}
runtimeDir := m.cfg.RuntimePath()
sockDir := m.socketDir(botID)
if err := os.MkdirAll(sockDir, 0o750); err != nil {
return fmt.Errorf("create socket dir: %w", err)
return ctr.ContainerSpec{}, fmt.Errorf("create socket dir: %w", err)
}
mounts := []ctr.MountSpec{
@@ -229,19 +247,37 @@ func (m *Manager) ensureBotWithImage(ctx context.Context, botID, image string) e
env = append(env, tzEnv...)
env = append(env, "BRIDGE_SOCKET_PATH=/run/memoh/bridge.sock")
return ctr.ContainerSpec{
Cmd: []string{"/opt/memoh/bridge"},
Mounts: mounts,
Env: env,
CDIDevices: normalizeWorkspaceGPUDevices(gpu.Devices),
}, nil
}
func (m *Manager) ensureBotWithImage(ctx context.Context, botID, image string, gpu WorkspaceGPUConfig) error {
if err := validateBotID(botID); err != nil {
return err
}
spec, err := m.buildWorkspaceContainerSpec(botID, gpu)
if err != nil {
return err
}
labels := map[string]string{
BotLabelKey: botID,
WorkspaceLabelKey: WorkspaceLabelValue,
}
if value := workspaceCDIDevicesLabelValue(gpu.Devices); value != "" {
labels[WorkspaceCDIDevicesLabelKey] = value
}
_, err = m.service.CreateContainer(ctx, ctr.CreateContainerRequest{
ID: ContainerPrefix + botID,
ImageRef: image,
Snapshotter: m.cfg.Snapshotter,
Labels: map[string]string{
BotLabelKey: botID,
WorkspaceLabelKey: WorkspaceLabelValue,
},
Spec: ctr.ContainerSpec{
Cmd: []string{"/opt/memoh/bridge"},
Mounts: mounts,
Env: env,
},
Labels: labels,
Spec: spec,
})
if err == nil {
return nil
@@ -275,7 +311,11 @@ func (m *Manager) Start(ctx context.Context, botID string) error {
if err != nil {
return err
}
return m.startWithResolvedImage(ctx, botID, image)
gpu, err := m.resolveWorkspaceGPU(ctx, botID)
if err != nil {
return err
}
return m.startWithResolvedConfig(ctx, botID, image, gpu)
}
// StartWithImage creates and starts the MCP container for a bot.
@@ -286,7 +326,11 @@ func (m *Manager) StartWithImage(ctx context.Context, botID, imageOverride strin
if image == "" {
return m.Start(ctx, botID)
}
return m.startWithResolvedImage(ctx, botID, config.NormalizeImageRef(image))
gpu, err := m.resolveWorkspaceGPU(ctx, botID)
if err != nil {
return err
}
return m.startWithResolvedConfig(ctx, botID, config.NormalizeImageRef(image), gpu)
}
// StartWithResolvedImage creates and starts the workspace container for a bot
@@ -296,10 +340,22 @@ func (m *Manager) StartWithResolvedImage(ctx context.Context, botID, image strin
if image == "" {
return errors.New("image is required")
}
return m.startWithResolvedImage(ctx, botID, image)
gpu, err := m.resolveWorkspaceGPU(ctx, botID)
if err != nil {
return err
}
return m.startWithResolvedConfig(ctx, botID, image, gpu)
}
func (m *Manager) startWithResolvedImage(ctx context.Context, botID, image string) error {
// StartWithResolvedConfig starts the bot's workspace container using an image
// and GPU configuration supplied by the caller. The gpu value is forwarded to
// startWithResolvedConfig as given. A blank image (after trimming whitespace)
// is rejected.
func (m *Manager) StartWithResolvedConfig(ctx context.Context, botID, image string, gpu WorkspaceGPUConfig) error {
	trimmed := strings.TrimSpace(image)
	if len(trimmed) == 0 {
		return errors.New("image is required")
	}
	return m.startWithResolvedConfig(ctx, botID, trimmed, gpu)
}
func (m *Manager) startWithResolvedConfig(ctx context.Context, botID, image string, gpu WorkspaceGPUConfig) error {
containerID := m.resolveContainerID(ctx, botID)
// Before creating a new container, check for an orphaned snapshot
@@ -311,7 +367,7 @@ func (m *Manager) startWithResolvedImage(ctx context.Context, botID, image strin
m.recoverOrphanedSnapshot(ctx, botID)
}
if err := m.ensureBotWithImage(ctx, botID, image); err != nil {
if err := m.ensureBotWithImage(ctx, botID, image, gpu); err != nil {
return err
}
+7 -1
View File
@@ -202,6 +202,10 @@ func (m *Manager) GetContainerInfo(ctx context.Context, botID string) (*Containe
if parseErr == nil {
row, dbErr := m.queries.GetContainerByBotID(ctx, pgBotID)
if dbErr == nil {
cdiDevices := []string(nil)
if liveInfo, liveErr := m.service.GetContainer(ctx, row.ContainerID); liveErr == nil {
cdiDevices = workspaceCDIDevicesFromLabels(liveInfo.Labels)
}
createdAt := time.Time{}
if row.CreatedAt.Valid {
createdAt = row.CreatedAt.Time
@@ -216,6 +220,7 @@ func (m *Manager) GetContainerInfo(ctx context.Context, botID string) (*Containe
Status: row.Status,
Namespace: row.Namespace,
ContainerPath: row.ContainerPath,
CDIDevices: cdiDevices,
TaskRunning: m.isTaskRunning(ctx, row.ContainerID),
HasPreservedData: m.HasPreservedData(botID),
Legacy: m.IsLegacyContainer(ctx, row.ContainerID),
@@ -242,6 +247,7 @@ func (m *Manager) GetContainerInfo(ctx context.Context, botID string) (*Containe
Image: info.Image,
Status: "unknown",
Namespace: m.namespace,
CDIDevices: workspaceCDIDevicesFromLabels(info.Labels),
TaskRunning: m.isTaskRunning(ctx, containerID),
HasPreservedData: m.HasPreservedData(botID),
Legacy: m.IsLegacyContainer(ctx, containerID),
@@ -270,7 +276,7 @@ func (m *Manager) SetupBotContainer(ctx context.Context, botID string) error {
return err
}
if err := m.startWithResolvedImage(ctx, botID, image); err != nil {
if err := m.StartWithResolvedImage(ctx, botID, image); err != nil {
m.logger.Error("setup bot container: start failed",
slog.String("bot_id", botID),
slog.Any("error", err))
+9 -41
View File
@@ -370,7 +370,7 @@ func (m *Manager) replaceContainerSnapshot(ctx context.Context, botID, container
if err := m.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{CleanupSnapshot: false}); err != nil {
return err
}
spec, err := m.buildVersionSpec(botID)
spec, err := m.buildVersionSpec(ctx, botID, workspaceCDIDevicesFromLabels(info.Labels))
if err != nil {
return err
}
@@ -402,47 +402,15 @@ func (m *Manager) replaceContainerSnapshot(ctx context.Context, botID, container
return nil
}
func (m *Manager) buildVersionSpec(botID string) (ctr.ContainerSpec, error) {
resolvPath, err := ctr.ResolveConfSource(m.dataRoot())
if err != nil {
return ctr.ContainerSpec{}, err
func (m *Manager) buildVersionSpec(ctx context.Context, botID string, cdiDevices []string) (ctr.ContainerSpec, error) {
if len(cdiDevices) == 0 {
gpu, err := m.resolveWorkspaceGPU(ctx, botID)
if err != nil {
return ctr.ContainerSpec{}, err
}
cdiDevices = gpu.Devices
}
runtimeDir := m.cfg.RuntimePath()
sockDir := m.socketDir(botID)
mounts := []ctr.MountSpec{
{
Destination: "/etc/resolv.conf",
Type: "bind",
Source: resolvPath,
Options: []string{"rbind", "ro"},
},
{
Destination: "/opt/memoh",
Type: "bind",
Source: runtimeDir,
Options: []string{"rbind", "ro"},
},
{
Destination: "/run/memoh",
Type: "bind",
Source: sockDir,
Options: []string{"rbind", "rw"},
},
}
tzMounts, tzEnv := ctr.TimezoneSpec()
mounts = append(mounts, tzMounts...)
env := make([]string, 0, len(tzEnv)+1)
env = append(env, tzEnv...)
env = append(env, "BRIDGE_SOCKET_PATH=/run/memoh/bridge.sock")
return ctr.ContainerSpec{
Cmd: []string{"/opt/memoh/bridge"},
Mounts: mounts,
Env: env,
}, nil
return m.buildWorkspaceContainerSpec(botID, WorkspaceGPUConfig{Devices: cdiDevices})
}
func (m *Manager) safeStopTask(ctx context.Context, containerID string) error {
+7
View File
@@ -738,18 +738,24 @@ export type HandlersChannelMeta = {
user_config_schema?: ChannelConfigSchema;
};
/** GPU section of a container creation request (see `HandlersCreateContainerRequest.gpu`). */
export type HandlersContainerGpuRequest = {
// Device identifiers to attach — presumably CDI names such as `nvidia.com/gpu=0`; confirm against the handler.
devices?: Array<string>;
};
/** Context-window usage figures; both fields are optional numbers. */
export type HandlersContextUsage = {
// NOTE(review): presumably the total window size in tokens — confirm against the API.
context_window?: number;
// NOTE(review): presumably the number of tokens currently consumed — confirm against the API.
used_tokens?: number;
};
/** Request body for creating a container; every field is optional. */
export type HandlersCreateContainerRequest = {
// Optional GPU settings for the new container.
gpu?: HandlersContainerGpuRequest;
// Base image reference; presumably the server default applies when omitted — confirm against the handler.
image?: string;
// NOTE(review): presumably requests restoring preserved data after creation — confirm against the handler.
restore_data?: boolean;
// NOTE(review): snapshotter name; accepted values are not visible here.
snapshotter?: string;
};
export type HandlersCreateContainerResponse = {
cdi_devices?: Array<string>;
container_id?: string;
data_restored?: boolean;
has_preserved_data?: boolean;
@@ -830,6 +836,7 @@ export type HandlersFsWriteRequest = {
};
export type HandlersGetContainerResponse = {
cdi_devices?: Array<string>;
container_id?: string;
container_path?: string;
created_at?: string;
+26
View File
@@ -10473,6 +10473,17 @@ const docTemplate = `{
}
}
},
"handlers.ContainerGPURequest": {
"type": "object",
"properties": {
"devices": {
"type": "array",
"items": {
"type": "string"
}
}
}
},
"handlers.ContextUsage": {
"type": "object",
"properties": {
@@ -10487,6 +10498,9 @@ const docTemplate = `{
"handlers.CreateContainerRequest": {
"type": "object",
"properties": {
"gpu": {
"$ref": "#/definitions/handlers.ContainerGPURequest"
},
"image": {
"type": "string"
},
@@ -10501,6 +10515,12 @@ const docTemplate = `{
"handlers.CreateContainerResponse": {
"type": "object",
"properties": {
"cdi_devices": {
"type": "array",
"items": {
"type": "string"
}
},
"container_id": {
"type": "string"
},
@@ -10692,6 +10712,12 @@ const docTemplate = `{
"handlers.GetContainerResponse": {
"type": "object",
"properties": {
"cdi_devices": {
"type": "array",
"items": {
"type": "string"
}
},
"container_id": {
"type": "string"
},
+26
View File
@@ -10464,6 +10464,17 @@
}
}
},
"handlers.ContainerGPURequest": {
"type": "object",
"properties": {
"devices": {
"type": "array",
"items": {
"type": "string"
}
}
}
},
"handlers.ContextUsage": {
"type": "object",
"properties": {
@@ -10478,6 +10489,9 @@
"handlers.CreateContainerRequest": {
"type": "object",
"properties": {
"gpu": {
"$ref": "#/definitions/handlers.ContainerGPURequest"
},
"image": {
"type": "string"
},
@@ -10492,6 +10506,12 @@
"handlers.CreateContainerResponse": {
"type": "object",
"properties": {
"cdi_devices": {
"type": "array",
"items": {
"type": "string"
}
},
"container_id": {
"type": "string"
},
@@ -10683,6 +10703,12 @@
"handlers.GetContainerResponse": {
"type": "object",
"properties": {
"cdi_devices": {
"type": "array",
"items": {
"type": "string"
}
},
"container_id": {
"type": "string"
},
+17
View File
@@ -1222,6 +1222,13 @@ definitions:
user_config_schema:
$ref: '#/definitions/channel.ConfigSchema'
type: object
handlers.ContainerGPURequest:
properties:
devices:
items:
type: string
type: array
type: object
handlers.ContextUsage:
properties:
context_window:
@@ -1231,6 +1238,8 @@ definitions:
type: object
handlers.CreateContainerRequest:
properties:
gpu:
$ref: '#/definitions/handlers.ContainerGPURequest'
image:
type: string
restore_data:
@@ -1240,6 +1249,10 @@ definitions:
type: object
handlers.CreateContainerResponse:
properties:
cdi_devices:
items:
type: string
type: array
container_id:
type: string
data_restored:
@@ -1363,6 +1376,10 @@ definitions:
type: object
handlers.GetContainerResponse:
properties:
cdi_devices:
items:
type: string
type: array
container_id:
type: string
container_path: