diff --git a/apps/web/src/i18n/locales/en.json b/apps/web/src/i18n/locales/en.json
index 09050332..22ee44f0 100644
--- a/apps/web/src/i18n/locales/en.json
+++ b/apps/web/src/i18n/locales/en.json
@@ -748,6 +748,16 @@
"createRestoreDataDescription": "If a previously exported backup or legacy bind-mounted data exists, it will be restored into `/data` after the container is created.",
"createImageLabel": "Base image",
"createImageDescription": "Docker image to use as the container base (e.g. debian:bookworm-slim, alpine:latest, ubuntu:24.04). Leave empty for the default.",
+ "createAdvancedTitle": "Advanced options",
+ "createAdvancedDescription": "Configure optional GPU access and raw CDI device mappings for this container.",
+ "createGpuLabel": "Enable GPU",
+ "createGpuDescription": "Attach host GPU access to the new container.",
+ "createGpuDevicesLabel": "CDI devices",
+ "createGpuDevicesDescription": "Enter one CDI device per line or separate them with commas, for example `nvidia.com/gpu=0` or `amd.com/gpu=0`. Turning GPU off explicitly clears the saved GPU preference.",
+ "createGpuDevicesPlaceholder": "nvidia.com/gpu=0\namd.com/gpu=0",
+ "gpuDevicesRequired": "At least one CDI device is required when GPU is enabled.",
+ "cdiDevicesEmpty": "No GPU attached",
+ "gpuRecreateHint": "Changing GPU settings requires recreating the container. A simple start or stop will not change the devices already attached.",
"deleteConfirm": "Are you sure you want to permanently delete this container? Unpreserved data cannot be recovered.",
"deletePreserveConfirm": "Are you sure you want to export `/data` and then delete this container?",
"restoreConfirm": "Are you sure you want to restore preserved data into this container's `/data`?",
@@ -797,6 +807,7 @@
"task": "Task",
"namespace": "Namespace",
"image": "Image",
+ "cdiDevices": "CDI Devices",
"hostPath": "Host Path",
"containerPath": "Container Path",
"preservedData": "Preserved Data",
diff --git a/apps/web/src/i18n/locales/zh.json b/apps/web/src/i18n/locales/zh.json
index d59da1cb..bed56e6c 100644
--- a/apps/web/src/i18n/locales/zh.json
+++ b/apps/web/src/i18n/locales/zh.json
@@ -744,6 +744,16 @@
"createRestoreDataDescription": "如果存在之前导出的备份或旧版 bind mount 数据,将在容器创建后恢复到 `/data`。",
"createImageLabel": "基础镜像",
"createImageDescription": "作为容器基础环境的 Docker 镜像(如 debian:bookworm-slim、alpine:latest、ubuntu:24.04)。留空则使用默认镜像。",
+ "createAdvancedTitle": "高级选项",
+ "createAdvancedDescription": "配置该容器的可选 GPU 访问能力与原始 CDI 设备映射。",
+ "createGpuLabel": "启用 GPU",
+ "createGpuDescription": "为新容器开启宿主机 GPU 访问。",
+ "createGpuDevicesLabel": "CDI 设备",
+ "createGpuDevicesDescription": "每行填写一个 CDI 设备名,或用逗号分隔多个设备名,例如 `nvidia.com/gpu=0` 或 `amd.com/gpu=0`。关闭 GPU 后会显式清空已保存的 GPU 偏好。",
+ "createGpuDevicesPlaceholder": "nvidia.com/gpu=0\namd.com/gpu=0",
+ "gpuDevicesRequired": "已启用 GPU 时,至少需要填写一个 CDI 设备名。",
+ "cdiDevicesEmpty": "未附加 GPU",
+ "gpuRecreateHint": "GPU 配置变更需要重建容器后才会生效,单纯启动或停止不会更新当前已附加的设备。",
"deleteConfirm": "确定要彻底删除这个容器吗?未保留的数据将无法恢复。",
"deletePreserveConfirm": "确定要先导出 `/data` 再删除这个容器吗?",
"restoreConfirm": "确定要将已保留的数据恢复到当前容器的 `/data` 吗?",
@@ -793,6 +803,7 @@
"task": "任务状态",
"namespace": "命名空间",
"image": "镜像",
+ "cdiDevices": "CDI 设备",
"hostPath": "主机路径",
"containerPath": "容器路径",
"preservedData": "保留数据",
diff --git a/apps/web/src/pages/bots/components/bot-container.vue b/apps/web/src/pages/bots/components/bot-container.vue
index 8d49da2c..02209fbf 100644
--- a/apps/web/src/pages/bots/components/bot-container.vue
+++ b/apps/web/src/pages/bots/components/bot-container.vue
@@ -4,6 +4,7 @@ import { toast } from 'vue-sonner'
import { useI18n } from 'vue-i18n'
import { useRoute } from 'vue-router'
import { useQuery } from '@pinia/colada'
+import { ChevronRight } from 'lucide-vue-next'
import {
deleteBotsByBotIdContainer,
getBotsByBotIdContainer,
@@ -25,7 +26,7 @@ import {
type ContainerCreateLayerStatus,
type ContainerCreateStreamEvent,
} from '@/composables/api/useContainerStream'
-import { Button, Input, Label, Separator, Spinner, Switch } from '@memohai/ui'
+import { Button, Collapsible, CollapsibleContent, CollapsibleTrigger, Input, Label, Separator, Spinner, Switch, Textarea } from '@memohai/ui'
import ConfirmPopover from '@/components/confirm-popover/index.vue'
import ContainerCreateProgress from './container-create-progress.vue'
import { useSyncedQueryParam } from '@/composables/useSyncedQueryParam'
@@ -59,6 +60,10 @@ const rollbackVersion = ref(null)
const createRestoreData = ref(false)
const createImage = ref('')
const createImagePrefilled = ref(false)
+const createGPUEnabled = ref(false)
+const createGPUDevices = ref('')
+const createGPUPrefilled = ref(false)
+const createAdvancedOpen = ref(false)
const newSnapshotName = ref('')
const importInputRef = ref(null)
@@ -178,7 +183,7 @@ async function handleRefreshContainer() {
await runContainerAction('refresh', () => loadContainerData(false))
}
-const { data: bot } = useQuery({
+const { data: bot, refetch: refetchBot } = useQuery({
key: () => ['bot', botId.value],
query: async () => {
const { data } = await getBotsById({ path: { id: botId.value }, throwOnError: true })
@@ -194,8 +199,48 @@ function rememberedWorkspaceImage(metadata: Record | undefined)
return typeof image === 'string' ? shortenImageRef(image) : ''
}
+type RememberedWorkspaceGPU = {
+ exists: boolean
+ devices: string[]
+}
+
+function rememberedWorkspaceGPU(metadata: Record | undefined): RememberedWorkspaceGPU {
+ const workspace = metadata?.workspace
+ if (!workspace || typeof workspace !== 'object' || Array.isArray(workspace)) {
+ return { exists: false, devices: [] }
+ }
+
+ const workspaceRecord = workspace as Record
+ if (!Object.prototype.hasOwnProperty.call(workspaceRecord, 'gpu')) {
+ return { exists: false, devices: [] }
+ }
+
+ const gpu = workspaceRecord.gpu
+ if (!gpu || typeof gpu !== 'object' || Array.isArray(gpu)) {
+ return { exists: true, devices: [] }
+ }
+
+ const rawDevices = (gpu as Record).devices
+ const devices = Array.isArray(rawDevices)
+ ? rawDevices.filter((value): value is string => typeof value === 'string').map(value => value.trim()).filter(Boolean)
+ : []
+
+ return { exists: true, devices: [...new Set(devices)] }
+}
+
+function parseCDIDevices(value: string): string[] {
+ return [...new Set(
+ value
+ .split(/[\n,]/)
+ .map(item => item.trim())
+ .filter(Boolean),
+ )]
+}
+
const rememberedCreateImage = computed(() => rememberedWorkspaceImage(bot.value?.metadata as Record | undefined))
+const rememberedCreateGPU = computed(() => rememberedWorkspaceGPU(bot.value?.metadata as Record | undefined))
const displayedContainerImage = computed(() => shortenImageRef(containerInfo.value?.image))
+const displayedCDIDevices = computed(() => containerInfo.value?.cdi_devices ?? [])
const { isPending: botLifecyclePending } = useBotStatusMeta(bot, t)
@@ -248,16 +293,29 @@ async function handleCreateContainer() {
containerAction.value = 'create'
createProgress.value = { phase: 'pulling' }
try {
+ const gpuDevices = parseCDIDevices(createGPUDevices.value)
+ if (createGPUEnabled.value && gpuDevices.length === 0) {
+ throw new Error(t('bots.container.gpuDevicesRequired'))
+ }
+
const body: HandlersCreateContainerRequest = {
restore_data: createRestoreData.value,
}
const trimmedImage = createImage.value.trim()
if (trimmedImage) body.image = trimmedImage
+ if (createGPUEnabled.value || rememberedCreateGPU.value.exists) {
+ body.gpu = {
+ devices: createGPUEnabled.value ? gpuDevices : [],
+ }
+ }
const { dataRestored } = await createContainerSSE(body)
createRestoreData.value = false
createImage.value = ''
+ createGPUEnabled.value = false
+ createGPUDevices.value = ''
await loadContainerData(false)
+ await refetchBot()
toast.success(dataRestored
? t('bots.container.createRestoreSuccess')
: t('bots.container.createSuccess'))
@@ -567,6 +625,8 @@ const activeTab = useSyncedQueryParam('tab', 'overview')
watch(containerMissing, (missing) => {
if (!missing) {
createImagePrefilled.value = false
+ createGPUPrefilled.value = false
+ createAdvancedOpen.value = false
}
})
@@ -577,6 +637,15 @@ watch([containerMissing, rememberedCreateImage], ([missing, remembered]) => {
createImagePrefilled.value = true
}, { immediate: true })
+watch([containerMissing, rememberedCreateGPU], ([missing, remembered]) => {
+ if (!missing || createGPUPrefilled.value) return
+ if (!remembered.exists) return
+ if (createGPUEnabled.value || createGPUDevices.value.trim()) return
+ createGPUEnabled.value = remembered.devices.length > 0
+ createGPUDevices.value = remembered.devices.join('\n')
+ createGPUPrefilled.value = true
+}, { immediate: true })
+
watch([activeTab, botId], ([tab]) => {
if (!botId.value) return
if (tab === 'container') {
@@ -685,6 +754,59 @@ watch([activeTab, botId], ([tab]) => {
+
diff --git a/devenv/docker-compose.yml b/devenv/docker-compose.yml
index 72bab321..84be0f10 100644
--- a/devenv/docker-compose.yml
+++ b/devenv/docker-compose.yml
@@ -88,6 +88,9 @@ services:
- containerd_data:/var/lib/containerd
- server_cni_state:/var/lib/cni
- memoh_data:/opt/memoh/data
+ # Expose host CDI specs to the nested containerd used for bot workspaces.
+ - /etc/cdi:/etc/cdi:ro
+ - /var/run/cdi:/var/run/cdi:ro
# Toolkit: run ./docker/toolkit/install.sh once before first use
- ../.toolkit:/opt/memoh/runtime/toolkit
- ../docker/toolkit/bin:/opt/memoh/runtime/toolkit/bin
diff --git a/docker-compose.yml b/docker-compose.yml
index b6624e5e..3fa6411f 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -45,6 +45,9 @@ services:
- containerd_data:/var/lib/containerd
- server_cni_state:/var/lib/cni
- memoh_data:/opt/memoh/data
+ # Expose host CDI specs to the nested containerd used for bot workspaces.
+ - /etc/cdi:/etc/cdi:ro
+ - /var/run/cdi:/var/run/cdi:ro
- /etc/localtime:/etc/localtime:ro
ports:
- "8080:8080"
diff --git a/docs/docs/getting-started/container.md b/docs/docs/getting-started/container.md
index 3e20b67d..60c932e4 100644
--- a/docs/docs/getting-started/container.md
+++ b/docs/docs/getting-started/container.md
@@ -33,6 +33,52 @@ The **Container** tab displays real-time data about the bot's runtime:
- **Image**: The Docker/Containerd image used as the base.
- **Paths**: Host and container paths for data persistence.
- **Tasks**: Number of active background tasks running in the container.
+- **CDI Devices**: The effective GPU CDI devices currently attached to the container, if any.
+
+---
+
+## Advanced: Provide CDI Devices
+
+Memoh can provide host devices to a bot container through CDI (Container Device Interface). This is an advanced capability for users who want to expose host-managed devices, most commonly GPUs, to the container runtime.
+
+In the Web UI, this capability is placed under **Advanced options** in the **Container** tab. It is optional and only needs to be configured when the bot must access CDI-backed devices from the host.
+
+### Configure CDI Devices
+
+1. Open the Bot's **Container** tab.
+2. Click **Create** if the container does not exist, or recreate the container if you need to change GPU settings.
+3. Expand **Advanced options**.
+4. Turn on **Enable GPU**.
+5. Enter one or more CDI device names in **CDI devices**.
+
+You can enter CDI device names one per line or separated with commas. Common GPU-related examples:
+
+- `nvidia.com/gpu=0`
+- `nvidia.com/gpu=all`
+- `amd.com/gpu=0`
+- `amd.com/gpu=all`
+
+### Host Requirements
+
+Before configuring CDI devices in Memoh, the host machine must already provide working device drivers, vendor toolkit support where required, and valid CDI specs. In practice, this usually means:
+
+- the host GPU works normally outside the container
+- CDI spec files exist under `/etc/cdi` or `/var/run/cdi`
+- the device name you enter in Memoh matches a real CDI device on the host
+
+To discover the exact CDI device names exposed by the host, use the vendor tool on the host machine:
+
+- NVIDIA: `nvidia-ctk cdi list`
+- AMD: `amd-ctk cdi list`
+
+If Memoh reports an error such as `unresolvable CDI devices`, the configured device name does not match any CDI device visible to the container runtime.
+
+### Important Behavior
+
+- CDI device settings are applied when the container is created. Updating the setting later requires recreating the container.
+- Stopping and starting an existing container does not change its attached CDI devices.
+- The container image still needs the appropriate user-space libraries and tools if you want to run CUDA or ROCm software inside the container.
+- After creation, the **Container** tab shows the effective attached CDI devices for verification.
---
diff --git a/go.mod b/go.mod
index 9c89367b..9112f352 100644
--- a/go.mod
+++ b/go.mod
@@ -71,6 +71,7 @@ require (
github.com/distribution/reference v0.6.0 // indirect
github.com/emersion/go-message v0.18.2 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
+ github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -104,6 +105,7 @@ require (
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/locker v1.0.1 // indirect
+ github.com/moby/sys/capability v0.4.0 // indirect
github.com/moby/sys/mountinfo v0.7.2 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/signal v0.7.1 // indirect
@@ -113,6 +115,7 @@ require (
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
github.com/oapi-codegen/runtime v1.1.2 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
+ github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
github.com/opencontainers/selinux v1.13.1 // indirect
github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pkg/errors v0.9.1 // indirect
@@ -134,6 +137,7 @@ require (
go.uber.org/dig v1.19.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.1 // indirect
+ go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/net v0.50.0 // indirect
@@ -143,4 +147,7 @@ require (
golang.org/x/time v0.14.0 // indirect
golang.org/x/tools v0.42.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect
+ sigs.k8s.io/yaml v1.6.0 // indirect
+ tags.cncf.io/container-device-interface v1.1.0 // indirect
+ tags.cncf.io/container-device-interface/specs-go v1.1.0 // indirect
)
diff --git a/go.sum b/go.sum
index bc746f2a..5445a5f5 100644
--- a/go.sum
+++ b/go.sum
@@ -90,6 +90,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
+github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
+github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
@@ -236,6 +238,8 @@ github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3N
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc=
+github.com/moby/sys/capability v0.4.0 h1:4D4mI6KlNtWMCM1Z/K0i7RV1FkX+DBDHKVJpCndZoHk=
+github.com/moby/sys/capability v0.4.0/go.mod h1:4g9IK291rVkms3LKCDOoYlnV8xKwoDTpIrNEE35Wq0I=
github.com/moby/sys/mountinfo v0.7.2 h1:1shs6aH5s4o5H2zQLn796ADW1wMrIwHsyJ2v9KouLrg=
github.com/moby/sys/mountinfo v0.7.2/go.mod h1:1YOa8w8Ih7uW0wALDUgT1dTTSBrZ+HiBLGws92L2RU4=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
@@ -271,6 +275,8 @@ github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJw
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/opencontainers/runtime-spec v1.3.0 h1:YZupQUdctfhpZy3TM39nN9Ika5CBWT5diQ8ibYCRkxg=
github.com/opencontainers/runtime-spec v1.3.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 h1:tAKu3NkKWZYpqBSOJKwTxT1wIGueiF7gcmcNgr5pNTY=
+github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116/go.mod h1:DKDEfzxvRkoQ6n9TGhxQgg2IM1lY4aM0eaQP4e3oElw=
github.com/opencontainers/selinux v1.13.1 h1:A8nNeceYngH9Ow++M+VVEwJVpdFmrlxsN22F+ISDCJE=
github.com/opencontainers/selinux v1.13.1/go.mod h1:S10WXZ/osk2kWOYKy1x2f/eXF5ZHJoUs8UU/2caNRbg=
github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
@@ -364,6 +370,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
+go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
+go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -519,3 +527,9 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
+sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
+tags.cncf.io/container-device-interface v1.1.0 h1:RnxNhxF1JOu6CJUVpetTYvrXHdxw9j9jFYgZpI+anSY=
+tags.cncf.io/container-device-interface v1.1.0/go.mod h1:76Oj0Yqp9FwTx/pySDc8Bxjpg+VqXfDb50cKAXVJ34Q=
+tags.cncf.io/container-device-interface/specs-go v1.1.0 h1:QRZVeAceQM+zTZe12eyfuJuuzp524EKYwhmvLd+h+yQ=
+tags.cncf.io/container-device-interface/specs-go v1.1.0/go.mod h1:u86hoFWqnh3hWz3esofRFKbI261bUlvUfLKGrDhJkgQ=
diff --git a/internal/containerd/service.go b/internal/containerd/service.go
index be15e07b..7cd2c69f 100644
--- a/internal/containerd/service.go
+++ b/internal/containerd/service.go
@@ -13,15 +13,18 @@ import (
tasksv1 "github.com/containerd/containerd/api/services/tasks/v1"
tasktypes "github.com/containerd/containerd/api/types/task"
containerd "github.com/containerd/containerd/v2/client"
+ "github.com/containerd/containerd/v2/core/containers"
"github.com/containerd/containerd/v2/core/images"
"github.com/containerd/containerd/v2/core/remotes/docker"
"github.com/containerd/containerd/v2/core/snapshots"
+ cdispec "github.com/containerd/containerd/v2/pkg/cdi"
"github.com/containerd/containerd/v2/pkg/cio"
"github.com/containerd/containerd/v2/pkg/namespaces"
"github.com/containerd/containerd/v2/pkg/oci"
"github.com/containerd/errdefs"
"github.com/opencontainers/image-spec/identity"
"github.com/opencontainers/runtime-spec/specs-go"
+ cdi "tags.cncf.io/container-device-interface/pkg/cdi"
"github.com/memohai/memoh/internal/config"
)
@@ -248,10 +251,26 @@ func specOptsFromSpec(spec ContainerSpec) []oci.SpecOpts {
}
opts = append(opts, oci.WithMounts(mounts))
}
+ if len(spec.CDIDevices) > 0 {
+ opts = append(opts, withStaticCDIRegistry())
+ opts = append(opts, cdispec.WithCDIDevices(spec.CDIDevices...))
+ }
return opts
}
+// withStaticCDIRegistry returns a SpecOpts that configures the CDI package with
+// auto-refresh disabled and then refreshes it once; it never fails the spec build.
+func withStaticCDIRegistry() oci.SpecOpts {
+	return func(_ context.Context, _ oci.Client, _ *containers.Container, _ *oci.Spec) error {
+		_ = cdi.Configure(cdi.WithAutoRefresh(false))
+		// Best-effort refresh: invalid specs for other vendors should not block
+		// injection of a resolvable device set for the current container.
+		_ = cdi.Refresh()
+		return nil
+	}
+}
+
func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContainerRequest) (ContainerInfo, error) {
if req.ID == "" || req.ImageRef == "" {
return ContainerInfo{}, ErrInvalidArgument
diff --git a/internal/containerd/service_apple.go b/internal/containerd/service_apple.go
index b3090666..fee4380e 100644
--- a/internal/containerd/service_apple.go
+++ b/internal/containerd/service_apple.go
@@ -161,6 +161,9 @@ func (s *AppleService) CreateContainer(ctx context.Context, req CreateContainerR
if req.ID == "" || req.ImageRef == "" {
return ContainerInfo{}, ErrInvalidArgument
}
+ if len(req.Spec.CDIDevices) > 0 {
+ return ContainerInfo{}, ErrNotSupported
+ }
if err := s.ensureHealthy(ctx); err != nil {
return ContainerInfo{}, err
}
diff --git a/internal/containerd/types.go b/internal/containerd/types.go
index 31242f0e..7d291509 100644
--- a/internal/containerd/types.go
+++ b/internal/containerd/types.go
@@ -91,7 +91,10 @@ type ContainerSpec struct {
User string
Mounts []MountSpec
DNS []string
- TTY bool
+ // CDIDevices contains fully-qualified CDI device names such as
+ // "nvidia.com/gpu=0" or "amd.com/gpu=0".
+ CDIDevices []string
+ TTY bool
}
type LayerStatus struct {
diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go
index a086aca5..fd5c8766 100644
--- a/internal/handlers/containerd.go
+++ b/internal/handlers/containerd.go
@@ -41,19 +41,25 @@ type ContainerdHandler struct {
policyService *policy.Service
}
+type ContainerGPURequest struct {
+ Devices []string `json:"devices,omitempty"`
+}
+
type CreateContainerRequest struct {
- Snapshotter string `json:"snapshotter,omitempty"`
- RestoreData bool `json:"restore_data,omitempty"`
- Image string `json:"image,omitempty"`
+ Snapshotter string `json:"snapshotter,omitempty"`
+ RestoreData bool `json:"restore_data,omitempty"`
+ Image string `json:"image,omitempty"`
+ GPU *ContainerGPURequest `json:"gpu,omitempty"`
}
type CreateContainerResponse struct {
- ContainerID string `json:"container_id"`
- Image string `json:"image"`
- Snapshotter string `json:"snapshotter"`
- Started bool `json:"started"`
- DataRestored bool `json:"data_restored"`
- HasPreservedData bool `json:"has_preserved_data"`
+ ContainerID string `json:"container_id"`
+ Image string `json:"image"`
+ Snapshotter string `json:"snapshotter"`
+ CDIDevices []string `json:"cdi_devices,omitempty"`
+ Started bool `json:"started"`
+ DataRestored bool `json:"data_restored"`
+ HasPreservedData bool `json:"has_preserved_data"`
}
// codesync(container-create-stream): keep these SSE payloads in sync with
@@ -92,6 +98,7 @@ type GetContainerResponse struct {
Status string `json:"status"`
Namespace string `json:"namespace"`
ContainerPath string `json:"container_path"`
+ CDIDevices []string `json:"cdi_devices,omitempty"`
TaskRunning bool `json:"task_running"`
HasPreservedData bool `json:"has_preserved_data"`
Legacy bool `json:"legacy"`
@@ -217,9 +224,18 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
slog.String("bot_id", botID), slog.Any("error", err))
return nil
}
+ gpu, err := h.manager.ResolveWorkspaceGPU(ctx, botID)
+ if err != nil {
+ h.logger.Error("resolve workspace gpu failed",
+ slog.String("bot_id", botID), slog.Any("error", err))
+ return nil
+ }
if imageOverride != "" {
image = config.NormalizeImageRef(imageOverride)
}
+ if req.GPU != nil {
+ gpu = workspace.WorkspaceGPUConfig{Devices: req.GPU.Devices}
+ }
snapshotter := strings.TrimSpace(req.Snapshotter)
if snapshotter == "" {
@@ -283,7 +299,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
send(createContainerRestoringEvent{Type: "restoring"})
}
- if err := h.manager.StartWithResolvedImage(ctx, botID, image); err != nil {
+ if err := h.manager.StartWithResolvedConfig(ctx, botID, image, gpu); err != nil {
h.logger.Error("container start failed",
slog.String("bot_id", botID), slog.Any("error", err))
sendError("container start failed: " + err.Error())
@@ -293,6 +309,12 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
h.logger.Warn("remember workspace image failed",
slog.String("bot_id", botID), slog.String("image", image), slog.Any("error", err))
}
+ if req.GPU != nil {
+ if err := h.manager.RememberWorkspaceGPU(ctx, botID, gpu); err != nil {
+ h.logger.Warn("remember workspace gpu failed",
+ slog.String("bot_id", botID), slog.Any("error", err))
+ }
+ }
containerID, err := h.manager.ContainerID(ctx, botID)
if err != nil {
@@ -315,6 +337,16 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
h.manager.RecordContainerRunning(ctx, botID, containerID, image)
+ status, statusErr := h.manager.GetContainerInfo(ctx, botID)
+ if statusErr != nil {
+ h.logger.Warn("load container status after start failed",
+ slog.String("bot_id", botID), slog.Any("error", statusErr))
+ }
+ cdiDevices := gpu.Devices
+ if status != nil {
+ cdiDevices = status.CDIDevices
+ }
+
// Phase 3: Complete
send(createContainerCompleteEvent{
Type: "complete",
@@ -322,6 +354,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
ContainerID: containerID,
Image: image,
Snapshotter: snapshotter,
+ CDIDevices: cdiDevices,
Started: true,
DataRestored: dataRestored,
HasPreservedData: h.manager.HasPreservedData(botID),
@@ -357,6 +390,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error {
Status: status.Status,
Namespace: status.Namespace,
ContainerPath: status.ContainerPath,
+ CDIDevices: status.CDIDevices,
TaskRunning: status.TaskRunning,
HasPreservedData: status.HasPreservedData,
Legacy: status.Legacy,
diff --git a/internal/workspace/gpu_labels_test.go b/internal/workspace/gpu_labels_test.go
new file mode 100644
index 00000000..2a0b90b8
--- /dev/null
+++ b/internal/workspace/gpu_labels_test.go
@@ -0,0 +1,29 @@
+package workspace
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestWorkspaceCDIDevicesLabelRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ devices := []string{" nvidia.com/gpu=0 ", "amd.com/gpu=1", "nvidia.com/gpu=0"}
+ value := workspaceCDIDevicesLabelValue(devices)
+ got := workspaceCDIDevicesFromLabels(map[string]string{
+ WorkspaceCDIDevicesLabelKey: value,
+ })
+
+ want := []string{"nvidia.com/gpu=0", "amd.com/gpu=1"}
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("expected devices %v, got %v", want, got)
+ }
+}
+
+func TestWorkspaceCDIDevicesFromLabelsIgnoresMissingValue(t *testing.T) {
+ t.Parallel()
+
+ if got := workspaceCDIDevicesFromLabels(nil); len(got) != 0 {
+ t.Fatalf("expected empty devices for nil labels, got %v", got)
+ }
+}
diff --git a/internal/workspace/image_preference.go b/internal/workspace/image_preference.go
index a298e202..2a2c4a5f 100644
--- a/internal/workspace/image_preference.go
+++ b/internal/workspace/image_preference.go
@@ -16,8 +16,14 @@ import (
const (
workspaceMetadataKey = "workspace"
workspaceImageMetadataKey = "image"
+ workspaceGPUMetadataKey = "gpu"
+ workspaceGPUDevicesKey = "devices"
)
+type WorkspaceGPUConfig struct {
+ Devices []string `json:"devices,omitempty"`
+}
+
func decodeBotMetadata(payload []byte) (map[string]any, error) {
if len(payload) == 0 {
return map[string]any{}, nil
@@ -61,6 +67,54 @@ func workspaceImageFromMetadata(metadata map[string]any) string {
return strings.TrimSpace(image)
}
+func normalizeWorkspaceGPUDevices(devices []string) []string {
+ if len(devices) == 0 {
+ return nil
+ }
+
+ seen := make(map[string]struct{}, len(devices))
+ normalized := make([]string, 0, len(devices))
+ for _, raw := range devices {
+ device := strings.TrimSpace(raw)
+ if device == "" {
+ continue
+ }
+ if _, ok := seen[device]; ok {
+ continue
+ }
+ seen[device] = struct{}{}
+ normalized = append(normalized, device)
+ }
+ return normalized
+}
+
+func workspaceGPUFromMetadata(metadata map[string]any) (WorkspaceGPUConfig, bool) {
+ section := workspaceSection(metadata)
+ raw, ok := section[workspaceGPUMetadataKey]
+ if !ok {
+ return WorkspaceGPUConfig{}, false
+ }
+
+ gpuSection, ok := raw.(map[string]any)
+ if !ok {
+ return WorkspaceGPUConfig{}, true
+ }
+
+ var devices []string
+ switch typed := gpuSection[workspaceGPUDevicesKey].(type) {
+ case []string:
+ devices = append(devices, typed...)
+ case []any:
+ for _, item := range typed {
+ if device, ok := item.(string); ok {
+ devices = append(devices, device)
+ }
+ }
+ }
+
+ return WorkspaceGPUConfig{Devices: normalizeWorkspaceGPUDevices(devices)}, true
+}
+
func withWorkspaceImagePreference(metadata map[string]any, image string) map[string]any {
next := cloneAnyMap(metadata)
section := workspaceSection(next)
@@ -81,6 +135,28 @@ func withoutWorkspaceImagePreference(metadata map[string]any) map[string]any {
return next
}
+func withWorkspaceGPUPreference(metadata map[string]any, gpu WorkspaceGPUConfig) map[string]any {
+ next := cloneAnyMap(metadata)
+ section := workspaceSection(next)
+ section[workspaceGPUMetadataKey] = map[string]any{
+ workspaceGPUDevicesKey: normalizeWorkspaceGPUDevices(gpu.Devices),
+ }
+ next[workspaceMetadataKey] = section
+ return next
+}
+
+func withoutWorkspaceGPUPreference(metadata map[string]any) map[string]any {
+ next := cloneAnyMap(metadata)
+ section := workspaceSection(next)
+ delete(section, workspaceGPUMetadataKey)
+ if len(section) == 0 {
+ delete(next, workspaceMetadataKey)
+ return next
+ }
+ next[workspaceMetadataKey] = section
+ return next
+}
+
func (m *Manager) botWorkspaceImagePreference(ctx context.Context, botID string) (string, error) {
if m.queries == nil {
return "", nil
@@ -132,6 +208,7 @@ func (m *Manager) updateBotWorkspaceImagePreference(ctx context.Context, botID,
ID: botUUID,
DisplayName: row.DisplayName,
AvatarUrl: row.AvatarUrl,
+ Timezone: row.Timezone,
IsActive: row.IsActive,
Metadata: payload,
})
@@ -146,10 +223,82 @@ func (m *Manager) ClearWorkspaceImagePreference(ctx context.Context, botID strin
return m.updateBotWorkspaceImagePreference(ctx, botID, "", true)
}
+func (m *Manager) botWorkspaceGPUPreference(ctx context.Context, botID string) (WorkspaceGPUConfig, bool, error) {
+ if m.queries == nil {
+ return WorkspaceGPUConfig{}, false, nil
+ }
+ botUUID, err := db.ParseUUID(botID)
+ if err != nil {
+ return WorkspaceGPUConfig{}, false, err
+ }
+ row, err := m.queries.GetBotByID(ctx, botUUID)
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ return WorkspaceGPUConfig{}, false, nil
+ }
+ return WorkspaceGPUConfig{}, false, err
+ }
+ metadata, err := decodeBotMetadata(row.Metadata)
+ if err != nil {
+ return WorkspaceGPUConfig{}, false, err
+ }
+ gpu, ok := workspaceGPUFromMetadata(metadata)
+ return gpu, ok, nil
+}
+
+func (m *Manager) updateBotWorkspaceGPUPreference(ctx context.Context, botID string, gpu WorkspaceGPUConfig, clearPreference bool) error {
+ if m.queries == nil {
+ return nil
+ }
+ botUUID, err := db.ParseUUID(botID)
+ if err != nil {
+ return err
+ }
+ row, err := m.queries.GetBotByID(ctx, botUUID)
+ if err != nil {
+ return err
+ }
+ metadata, err := decodeBotMetadata(row.Metadata)
+ if err != nil {
+ return err
+ }
+ if clearPreference {
+ metadata = withoutWorkspaceGPUPreference(metadata)
+ } else {
+ metadata = withWorkspaceGPUPreference(metadata, gpu)
+ }
+ payload, err := json.Marshal(metadata)
+ if err != nil {
+ return err
+ }
+ _, err = m.queries.UpdateBotProfile(ctx, dbsqlc.UpdateBotProfileParams{
+ ID: botUUID,
+ DisplayName: row.DisplayName,
+ AvatarUrl: row.AvatarUrl,
+ Timezone: row.Timezone,
+ IsActive: row.IsActive,
+ Metadata: payload,
+ })
+ return err
+}
+
+func (m *Manager) RememberWorkspaceGPU(ctx context.Context, botID string, gpu WorkspaceGPUConfig) error {
+ gpu.Devices = normalizeWorkspaceGPUDevices(gpu.Devices)
+ return m.updateBotWorkspaceGPUPreference(ctx, botID, gpu, false)
+}
+
+func (m *Manager) ClearWorkspaceGPUPreference(ctx context.Context, botID string) error {
+ return m.updateBotWorkspaceGPUPreference(ctx, botID, WorkspaceGPUConfig{}, true)
+}
+
// ResolveWorkspaceImage is the exported entry point for resolveWorkspaceImage;
// it forwards ctx and botID unchanged and returns that call's result.
func (m *Manager) ResolveWorkspaceImage(ctx context.Context, botID string) (string, error) {
return m.resolveWorkspaceImage(ctx, botID)
}
+// ResolveWorkspaceGPU is the exported counterpart of resolveWorkspaceGPU and
+// returns its result unchanged.
+func (m *Manager) ResolveWorkspaceGPU(ctx context.Context, botID string) (WorkspaceGPUConfig, error) {
+	gpu, err := m.resolveWorkspaceGPU(ctx, botID)
+	return gpu, err
+}
+
func (m *Manager) resolveWorkspaceImage(ctx context.Context, botID string) (string, error) {
if m.queries != nil {
pgBotID, err := db.ParseUUID(botID)
@@ -174,3 +323,16 @@ func (m *Manager) resolveWorkspaceImage(ctx context.Context, botID string) (stri
return m.imageRef(), nil
}
+
+// resolveWorkspaceGPU returns the GPU configuration to use for the bot: the
+// stored preference with its device list normalized, or the zero
+// configuration when no preference is stored.
+func (m *Manager) resolveWorkspaceGPU(ctx context.Context, botID string) (WorkspaceGPUConfig, error) {
+	stored, found, err := m.botWorkspaceGPUPreference(ctx, botID)
+	switch {
+	case err != nil:
+		return WorkspaceGPUConfig{}, err
+	case !found:
+		return WorkspaceGPUConfig{}, nil
+	}
+
+	stored.Devices = normalizeWorkspaceGPUDevices(stored.Devices)
+	return stored, nil
+}
diff --git a/internal/workspace/image_preference_test.go b/internal/workspace/image_preference_test.go
index dcbe68fc..6b89e401 100644
--- a/internal/workspace/image_preference_test.go
+++ b/internal/workspace/image_preference_test.go
@@ -51,3 +51,71 @@ func TestWithoutWorkspaceImagePreferenceRemovesOnlyImageKey(t *testing.T) {
t.Fatalf("expected unrelated workspace metadata to remain, got %#v", workspace)
}
}
+
+// TestWorkspaceGPUMetadataRoundTrip checks that a stored GPU preference is
+// read back with trimmed, de-duplicated devices and that sibling workspace
+// metadata survives the update.
+func TestWorkspaceGPUMetadataRoundTrip(t *testing.T) {
+	t.Parallel()
+
+	existingWorkspace := map[string]any{"keep": "value"}
+	metadata := map[string]any{workspaceMetadataKey: existingWorkspace}
+	requested := WorkspaceGPUConfig{
+		Devices: []string{" nvidia.com/gpu=0 ", "amd.com/gpu=1", "nvidia.com/gpu=0"},
+	}
+
+	updated := withWorkspaceGPUPreference(metadata, requested)
+
+	gpu, found := workspaceGPUFromMetadata(updated)
+	if !found {
+		t.Fatal("expected gpu preference to be present")
+	}
+
+	want := []string{"nvidia.com/gpu=0", "amd.com/gpu=1"}
+	got := gpu.Devices
+	mismatch := len(got) != len(want)
+	for i := 0; !mismatch && i < len(want); i++ {
+		mismatch = got[i] != want[i]
+	}
+	if mismatch {
+		t.Fatalf("expected normalized gpu devices %v, got %v", want, got)
+	}
+
+	workspace, isMap := updated[workspaceMetadataKey].(map[string]any)
+	if !isMap {
+		t.Fatal("expected workspace metadata section")
+	}
+	if kept := workspace["keep"]; kept != "value" {
+		t.Fatalf("expected existing workspace metadata to be preserved, got %#v", workspace)
+	}
+}
+
+func TestWorkspaceGPUExplicitDisableRemainsPresent(t *testing.T) {
+ t.Parallel()
+
+ metadata := withWorkspaceGPUPreference(map[string]any{}, WorkspaceGPUConfig{})
+
+ gpu, ok := workspaceGPUFromMetadata(metadata)
+ if !ok {
+ t.Fatal("expected gpu preference key to remain present")
+ }
+ if len(gpu.Devices) != 0 {
+ t.Fatalf("expected explicit disable with no devices, got %#v", gpu.Devices)
+ }
+}
+
+// TestWithoutWorkspaceGPUPreferenceRemovesOnlyGPUKey checks that clearing the
+// GPU preference drops the GPU entry while leaving other workspace metadata
+// in place.
+func TestWithoutWorkspaceGPUPreferenceRemovesOnlyGPUKey(t *testing.T) {
+	t.Parallel()
+
+	gpuSection := map[string]any{
+		workspaceGPUDevicesKey: []any{"nvidia.com/gpu=all"},
+	}
+	metadata := map[string]any{
+		workspaceMetadataKey: map[string]any{
+			workspaceGPUMetadataKey: gpuSection,
+			"keep":                  true,
+		},
+	}
+
+	updated := withoutWorkspaceGPUPreference(metadata)
+
+	_, stillPresent := workspaceGPUFromMetadata(updated)
+	if stillPresent {
+		t.Fatal("expected gpu preference to be cleared")
+	}
+
+	workspace, isMap := updated[workspaceMetadataKey].(map[string]any)
+	if !isMap {
+		t.Fatal("expected workspace metadata section to remain")
+	}
+	if kept := workspace["keep"]; kept != true {
+		t.Fatalf("expected unrelated workspace metadata to remain, got %#v", workspace)
+	}
+}
diff --git a/internal/workspace/manager.go b/internal/workspace/manager.go
index 7315e75e..8962e216 100644
--- a/internal/workspace/manager.go
+++ b/internal/workspace/manager.go
@@ -22,11 +22,12 @@ import (
)
const (
- BotLabelKey = "memoh.bot_id"
- WorkspaceLabelKey = "memoh.workspace"
- WorkspaceLabelValue = "v3"
- ContainerPrefix = "workspace-"
- LegacyContainerPrefix = "mcp-"
+ BotLabelKey = "memoh.bot_id"
+ WorkspaceLabelKey = "memoh.workspace"
+ WorkspaceLabelValue = "v3"
+ WorkspaceCDIDevicesLabelKey = "memoh.workspace.cdi_devices" // container label holding the comma-joined CDI device list; set only when the list is non-empty
+ ContainerPrefix = "workspace-"
+ LegacyContainerPrefix = "mcp-"
legacyGRPCPort = 9090
)
@@ -41,6 +42,7 @@ type ContainerStatus struct {
Status string `json:"status"`
Namespace string `json:"namespace"`
ContainerPath string `json:"container_path"`
+ CDIDevices []string `json:"cdi_devices,omitempty"`
TaskRunning bool `json:"task_running"`
HasPreservedData bool `json:"has_preserved_data"`
Legacy bool `json:"legacy"`
@@ -183,23 +185,39 @@ func (m *Manager) EnsureBot(ctx context.Context, botID, imageOverride string) er
if imageOverride != "" {
image = config.NormalizeImageRef(imageOverride)
}
- return m.ensureBotWithImage(ctx, botID, image)
-}
-
-func (m *Manager) ensureBotWithImage(ctx context.Context, botID, image string) error {
- if err := validateBotID(botID); err != nil {
- return err
- }
-
- resolvPath, err := ctr.ResolveConfSource(m.dataRoot())
+ gpu, err := m.resolveWorkspaceGPU(ctx, botID)
if err != nil {
return err
}
+ return m.ensureBotWithImage(ctx, botID, image, gpu)
+}
+
+// workspaceCDIDevicesLabelValue renders devices as the value stored under
+// WorkspaceCDIDevicesLabelKey: the normalized device names joined by commas.
+func workspaceCDIDevicesLabelValue(devices []string) string {
+	return strings.Join(normalizeWorkspaceGPUDevices(devices), ",")
+}
+
+// workspaceCDIDevicesFromLabels reads the CDI device list back from container
+// labels. It returns nil when the label is missing or blank; otherwise the
+// comma-separated value is split and normalized.
+func workspaceCDIDevicesFromLabels(labels map[string]string) []string {
+	raw, present := labels[WorkspaceCDIDevicesLabelKey]
+	if !present {
+		return nil
+	}
+	if raw = strings.TrimSpace(raw); raw == "" {
+		return nil
+	}
+	return normalizeWorkspaceGPUDevices(strings.Split(raw, ","))
+}
+
+func (m *Manager) buildWorkspaceContainerSpec(botID string, gpu WorkspaceGPUConfig) (ctr.ContainerSpec, error) {
+ resolvPath, err := ctr.ResolveConfSource(m.dataRoot())
+ if err != nil {
+ return ctr.ContainerSpec{}, err
+ }
runtimeDir := m.cfg.RuntimePath()
sockDir := m.socketDir(botID)
if err := os.MkdirAll(sockDir, 0o750); err != nil {
- return fmt.Errorf("create socket dir: %w", err)
+ return ctr.ContainerSpec{}, fmt.Errorf("create socket dir: %w", err)
}
mounts := []ctr.MountSpec{
@@ -229,19 +247,37 @@ func (m *Manager) ensureBotWithImage(ctx context.Context, botID, image string) e
env = append(env, tzEnv...)
env = append(env, "BRIDGE_SOCKET_PATH=/run/memoh/bridge.sock")
+ return ctr.ContainerSpec{
+ Cmd: []string{"/opt/memoh/bridge"},
+ Mounts: mounts,
+ Env: env,
+ CDIDevices: normalizeWorkspaceGPUDevices(gpu.Devices),
+ }, nil
+}
+
+func (m *Manager) ensureBotWithImage(ctx context.Context, botID, image string, gpu WorkspaceGPUConfig) error {
+ if err := validateBotID(botID); err != nil {
+ return err
+ }
+ spec, err := m.buildWorkspaceContainerSpec(botID, gpu)
+ if err != nil {
+ return err
+ }
+
+ labels := map[string]string{
+ BotLabelKey: botID,
+ WorkspaceLabelKey: WorkspaceLabelValue,
+ }
+ if value := workspaceCDIDevicesLabelValue(gpu.Devices); value != "" {
+ labels[WorkspaceCDIDevicesLabelKey] = value
+ }
+
_, err = m.service.CreateContainer(ctx, ctr.CreateContainerRequest{
ID: ContainerPrefix + botID,
ImageRef: image,
Snapshotter: m.cfg.Snapshotter,
- Labels: map[string]string{
- BotLabelKey: botID,
- WorkspaceLabelKey: WorkspaceLabelValue,
- },
- Spec: ctr.ContainerSpec{
- Cmd: []string{"/opt/memoh/bridge"},
- Mounts: mounts,
- Env: env,
- },
+ Labels: labels,
+ Spec: spec,
})
if err == nil {
return nil
@@ -275,7 +311,11 @@ func (m *Manager) Start(ctx context.Context, botID string) error {
if err != nil {
return err
}
- return m.startWithResolvedImage(ctx, botID, image)
+ gpu, err := m.resolveWorkspaceGPU(ctx, botID)
+ if err != nil {
+ return err
+ }
+ return m.startWithResolvedConfig(ctx, botID, image, gpu)
}
// StartWithImage creates and starts the MCP container for a bot.
@@ -286,7 +326,11 @@ func (m *Manager) StartWithImage(ctx context.Context, botID, imageOverride strin
if image == "" {
return m.Start(ctx, botID)
}
- return m.startWithResolvedImage(ctx, botID, config.NormalizeImageRef(image))
+ gpu, err := m.resolveWorkspaceGPU(ctx, botID)
+ if err != nil {
+ return err
+ }
+ return m.startWithResolvedConfig(ctx, botID, config.NormalizeImageRef(image), gpu)
}
// StartWithResolvedImage creates and starts the workspace container for a bot
@@ -296,10 +340,22 @@ func (m *Manager) StartWithResolvedImage(ctx context.Context, botID, image strin
if image == "" {
return errors.New("image is required")
}
- return m.startWithResolvedImage(ctx, botID, image)
+ gpu, err := m.resolveWorkspaceGPU(ctx, botID)
+ if err != nil {
+ return err
+ }
+ return m.startWithResolvedConfig(ctx, botID, image, gpu)
}
-func (m *Manager) startWithResolvedImage(ctx context.Context, botID, image string) error {
+// StartWithResolvedConfig hands an explicit image and GPU configuration to
+// startWithResolvedConfig. The image is trimmed of surrounding whitespace and
+// must not be empty.
+func (m *Manager) StartWithResolvedConfig(ctx context.Context, botID, image string, gpu WorkspaceGPUConfig) error {
+	trimmed := strings.TrimSpace(image)
+	if len(trimmed) == 0 {
+		return errors.New("image is required")
+	}
+	return m.startWithResolvedConfig(ctx, botID, trimmed, gpu)
+}
+
+func (m *Manager) startWithResolvedConfig(ctx context.Context, botID, image string, gpu WorkspaceGPUConfig) error {
containerID := m.resolveContainerID(ctx, botID)
// Before creating a new container, check for an orphaned snapshot
@@ -311,7 +367,7 @@ func (m *Manager) startWithResolvedImage(ctx context.Context, botID, image strin
m.recoverOrphanedSnapshot(ctx, botID)
}
- if err := m.ensureBotWithImage(ctx, botID, image); err != nil {
+ if err := m.ensureBotWithImage(ctx, botID, image, gpu); err != nil {
return err
}
diff --git a/internal/workspace/manager_lifecycle.go b/internal/workspace/manager_lifecycle.go
index 73fdbbd5..7b859d0e 100644
--- a/internal/workspace/manager_lifecycle.go
+++ b/internal/workspace/manager_lifecycle.go
@@ -202,6 +202,10 @@ func (m *Manager) GetContainerInfo(ctx context.Context, botID string) (*Containe
if parseErr == nil {
row, dbErr := m.queries.GetContainerByBotID(ctx, pgBotID)
if dbErr == nil {
+ cdiDevices := []string(nil)
+ if liveInfo, liveErr := m.service.GetContainer(ctx, row.ContainerID); liveErr == nil {
+ cdiDevices = workspaceCDIDevicesFromLabels(liveInfo.Labels)
+ }
createdAt := time.Time{}
if row.CreatedAt.Valid {
createdAt = row.CreatedAt.Time
@@ -216,6 +220,7 @@ func (m *Manager) GetContainerInfo(ctx context.Context, botID string) (*Containe
Status: row.Status,
Namespace: row.Namespace,
ContainerPath: row.ContainerPath,
+ CDIDevices: cdiDevices,
TaskRunning: m.isTaskRunning(ctx, row.ContainerID),
HasPreservedData: m.HasPreservedData(botID),
Legacy: m.IsLegacyContainer(ctx, row.ContainerID),
@@ -242,6 +247,7 @@ func (m *Manager) GetContainerInfo(ctx context.Context, botID string) (*Containe
Image: info.Image,
Status: "unknown",
Namespace: m.namespace,
+ CDIDevices: workspaceCDIDevicesFromLabels(info.Labels),
TaskRunning: m.isTaskRunning(ctx, containerID),
HasPreservedData: m.HasPreservedData(botID),
Legacy: m.IsLegacyContainer(ctx, containerID),
@@ -270,7 +276,7 @@ func (m *Manager) SetupBotContainer(ctx context.Context, botID string) error {
return err
}
- if err := m.startWithResolvedImage(ctx, botID, image); err != nil {
+ if err := m.StartWithResolvedImage(ctx, botID, image); err != nil {
m.logger.Error("setup bot container: start failed",
slog.String("bot_id", botID),
slog.Any("error", err))
diff --git a/internal/workspace/versioning.go b/internal/workspace/versioning.go
index b5c70d9c..728ede12 100644
--- a/internal/workspace/versioning.go
+++ b/internal/workspace/versioning.go
@@ -370,7 +370,7 @@ func (m *Manager) replaceContainerSnapshot(ctx context.Context, botID, container
if err := m.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{CleanupSnapshot: false}); err != nil {
return err
}
- spec, err := m.buildVersionSpec(botID)
+ spec, err := m.buildVersionSpec(ctx, botID, workspaceCDIDevicesFromLabels(info.Labels))
if err != nil {
return err
}
@@ -402,47 +402,15 @@ func (m *Manager) replaceContainerSnapshot(ctx context.Context, botID, container
return nil
}
-func (m *Manager) buildVersionSpec(botID string) (ctr.ContainerSpec, error) {
- resolvPath, err := ctr.ResolveConfSource(m.dataRoot())
- if err != nil {
- return ctr.ContainerSpec{}, err
+// buildVersionSpec builds the container spec used when replaceContainerSnapshot
+// recreates a container. cdiDevices is the device list to attach; when it is
+// empty, the bot's stored GPU preference is resolved and used instead. The spec
+// itself comes from buildWorkspaceContainerSpec.
+// NOTE(review): an empty cdiDevices may also mean the previous container was
+// deliberately created without a GPU; this fallback would then attach the
+// currently saved preference — confirm that is intended.
+func (m *Manager) buildVersionSpec(ctx context.Context, botID string, cdiDevices []string) (ctr.ContainerSpec, error) {
+ if len(cdiDevices) == 0 {
+ gpu, err := m.resolveWorkspaceGPU(ctx, botID)
+ if err != nil {
+ return ctr.ContainerSpec{}, err
+ }
+ cdiDevices = gpu.Devices
}
-
- runtimeDir := m.cfg.RuntimePath()
- sockDir := m.socketDir(botID)
-
- mounts := []ctr.MountSpec{
- {
- Destination: "/etc/resolv.conf",
- Type: "bind",
- Source: resolvPath,
- Options: []string{"rbind", "ro"},
- },
- {
- Destination: "/opt/memoh",
- Type: "bind",
- Source: runtimeDir,
- Options: []string{"rbind", "ro"},
- },
- {
- Destination: "/run/memoh",
- Type: "bind",
- Source: sockDir,
- Options: []string{"rbind", "rw"},
- },
- }
- tzMounts, tzEnv := ctr.TimezoneSpec()
- mounts = append(mounts, tzMounts...)
-
- env := make([]string, 0, len(tzEnv)+1)
- env = append(env, tzEnv...)
- env = append(env, "BRIDGE_SOCKET_PATH=/run/memoh/bridge.sock")
-
- return ctr.ContainerSpec{
- Cmd: []string{"/opt/memoh/bridge"},
- Mounts: mounts,
- Env: env,
- }, nil
+ return m.buildWorkspaceContainerSpec(botID, WorkspaceGPUConfig{Devices: cdiDevices})
}
func (m *Manager) safeStopTask(ctx context.Context, containerID string) error {
diff --git a/packages/sdk/src/types.gen.ts b/packages/sdk/src/types.gen.ts
index 620727f1..73d38851 100644
--- a/packages/sdk/src/types.gen.ts
+++ b/packages/sdk/src/types.gen.ts
@@ -738,18 +738,24 @@ export type HandlersChannelMeta = {
user_config_schema?: ChannelConfigSchema;
};
+/**
+ * GPU section of a container creation request
+ * (referenced by `HandlersCreateContainerRequest.gpu`).
+ */
+export type HandlersContainerGpuRequest = {
+    /**
+     * Device names to attach, e.g. `nvidia.com/gpu=0`.
+     * The OpenAPI schema declares this as an array of strings.
+     */
+    devices?: Array<string>;
+};
+
export type HandlersContextUsage = {
context_window?: number;
used_tokens?: number;
};
export type HandlersCreateContainerRequest = {
+ gpu?: HandlersContainerGpuRequest; // optional GPU request carrying the list of device names
image?: string;
restore_data?: boolean;
snapshotter?: string;
};
export type HandlersCreateContainerResponse = {
+ cdi_devices?: Array;
container_id?: string;
data_restored?: boolean;
has_preserved_data?: boolean;
@@ -830,6 +836,7 @@ export type HandlersFsWriteRequest = {
};
export type HandlersGetContainerResponse = {
+ cdi_devices?: Array;
container_id?: string;
container_path?: string;
created_at?: string;
diff --git a/spec/docs.go b/spec/docs.go
index 08931bdd..fb827094 100644
--- a/spec/docs.go
+++ b/spec/docs.go
@@ -10473,6 +10473,17 @@ const docTemplate = `{
}
}
},
+ "handlers.ContainerGPURequest": {
+ "type": "object",
+ "properties": {
+ "devices": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ }
+ },
"handlers.ContextUsage": {
"type": "object",
"properties": {
@@ -10487,6 +10498,9 @@ const docTemplate = `{
"handlers.CreateContainerRequest": {
"type": "object",
"properties": {
+ "gpu": {
+ "$ref": "#/definitions/handlers.ContainerGPURequest"
+ },
"image": {
"type": "string"
},
@@ -10501,6 +10515,12 @@ const docTemplate = `{
"handlers.CreateContainerResponse": {
"type": "object",
"properties": {
+ "cdi_devices": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
"container_id": {
"type": "string"
},
@@ -10692,6 +10712,12 @@ const docTemplate = `{
"handlers.GetContainerResponse": {
"type": "object",
"properties": {
+ "cdi_devices": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
"container_id": {
"type": "string"
},
diff --git a/spec/swagger.json b/spec/swagger.json
index c0d26df8..4221e35a 100644
--- a/spec/swagger.json
+++ b/spec/swagger.json
@@ -10464,6 +10464,17 @@
}
}
},
+ "handlers.ContainerGPURequest": {
+ "type": "object",
+ "properties": {
+ "devices": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ }
+ },
"handlers.ContextUsage": {
"type": "object",
"properties": {
@@ -10478,6 +10489,9 @@
"handlers.CreateContainerRequest": {
"type": "object",
"properties": {
+ "gpu": {
+ "$ref": "#/definitions/handlers.ContainerGPURequest"
+ },
"image": {
"type": "string"
},
@@ -10492,6 +10506,12 @@
"handlers.CreateContainerResponse": {
"type": "object",
"properties": {
+ "cdi_devices": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
"container_id": {
"type": "string"
},
@@ -10683,6 +10703,12 @@
"handlers.GetContainerResponse": {
"type": "object",
"properties": {
+ "cdi_devices": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
"container_id": {
"type": "string"
},
diff --git a/spec/swagger.yaml b/spec/swagger.yaml
index 2f0de79b..2afefa92 100644
--- a/spec/swagger.yaml
+++ b/spec/swagger.yaml
@@ -1222,6 +1222,13 @@ definitions:
user_config_schema:
$ref: '#/definitions/channel.ConfigSchema'
type: object
+ handlers.ContainerGPURequest:
+ properties:
+ devices:
+ items:
+ type: string
+ type: array
+ type: object
handlers.ContextUsage:
properties:
context_window:
@@ -1231,6 +1238,8 @@ definitions:
type: object
handlers.CreateContainerRequest:
properties:
+ gpu:
+ $ref: '#/definitions/handlers.ContainerGPURequest'
image:
type: string
restore_data:
@@ -1240,6 +1249,10 @@ definitions:
type: object
handlers.CreateContainerResponse:
properties:
+ cdi_devices:
+ items:
+ type: string
+ type: array
container_id:
type: string
data_restored:
@@ -1363,6 +1376,10 @@ definitions:
type: object
handlers.GetContainerResponse:
properties:
+ cdi_devices:
+ items:
+ type: string
+ type: array
container_id:
type: string
container_path: