mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(skills): add effective skill resolution and actions (#377)
* feat(skills): add effective skill resolution and actions * refactor(workspace): normalize skill-related env and prompt * chore(api): regenerate skills OpenAPI and SDK artifacts * feat(web): surface effective skill state in console * test(skills): cover API and runtime effective state * fix(web): show adopt action for discovered skills * fix(web): align skill header and show stateful visibility icon * refactor(web): compact skill metadata on narrow layouts * fix(web): constrain long skill text in cards * refactor(skills): narrow default discovery roots * fix(skills): harden managed skill path validation * feat: add path in the results of `use_skill` --------- Co-authored-by: Acbox <acbox0328@gmail.com>
This commit is contained in:
@@ -619,6 +619,7 @@ func (a *Agent) assembleTools(ctx context.Context, cfg RunConfig, emitter tools.
|
||||
skillsMap[s.Name] = tools.SkillDetail{
|
||||
Description: s.Description,
|
||||
Content: s.Content,
|
||||
Path: s.Path,
|
||||
}
|
||||
}
|
||||
session := tools.SessionContext{
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
skillset "github.com/memohai/memoh/internal/skills"
|
||||
)
|
||||
|
||||
//go:embed prompts/*.md
|
||||
@@ -205,7 +207,8 @@ func buildSkillsSection(skills []SkillEntry) string {
|
||||
})
|
||||
var sb strings.Builder
|
||||
sb.WriteString("## Skills\n\n")
|
||||
sb.WriteString("Skills are stored in `{{home}}/skills/`. ")
|
||||
sb.WriteString("Memoh-managed skills are stored in `" + skillset.ManagedDir() + "/`. ")
|
||||
sb.WriteString("Compatible external skill directories inside the bot container may also be discovered automatically. ")
|
||||
sb.WriteString("Each skill is a `SKILL.md` file inside a named subdirectory.\n\n")
|
||||
sb.WriteString("Call `use_skill` with the skill name to load its full instructions before following them. ")
|
||||
sb.WriteString("Only activate a skill when it is relevant to the current task.\n\n")
|
||||
|
||||
@@ -61,6 +61,7 @@ func (*SkillProvider) Tools(_ context.Context, session SessionContext) ([]sdk.To
|
||||
"skillName": skillName,
|
||||
"description": skill.Description,
|
||||
"content": skill.Content,
|
||||
"path": skill.Path,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUseSkillReturnsPath(t *testing.T) {
|
||||
provider := NewSkillProvider(nil)
|
||||
|
||||
toolset, err := provider.Tools(context.Background(), SessionContext{
|
||||
Skills: map[string]SkillDetail{
|
||||
"pdf": {
|
||||
Description: "Read PDF instructions",
|
||||
Content: "Use a PDF-aware workflow.",
|
||||
Path: "/data/.agents/skills/pdf",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Tools returned error: %v", err)
|
||||
}
|
||||
if len(toolset) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(toolset))
|
||||
}
|
||||
|
||||
result, err := toolset[0].Execute(nil, map[string]any{
|
||||
"skillName": "pdf",
|
||||
"reason": "Need to process a PDF attachment",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute returned error: %v", err)
|
||||
}
|
||||
|
||||
payload, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("result type = %T, want map[string]any", result)
|
||||
}
|
||||
if got := payload["path"]; got != "/data/.agents/skills/pdf" {
|
||||
t.Fatalf("path = %#v, want %q", got, "/data/.agents/skills/pdf")
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
type SkillDetail struct {
|
||||
Description string
|
||||
Content string
|
||||
Path string
|
||||
}
|
||||
|
||||
// StreamEventType identifies the kind of stream event emitted by tools.
|
||||
|
||||
@@ -30,6 +30,7 @@ type SkillEntry struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string
|
||||
Path string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ type SkillEntry struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string
|
||||
Path string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
@@ -703,6 +704,7 @@ func normalizeGatewaySkill(entry SkillEntry) (agentpkg.SkillEntry, bool) {
|
||||
Name: name,
|
||||
Description: description,
|
||||
Content: content,
|
||||
Path: strings.TrimSpace(entry.Path),
|
||||
Metadata: entry.Metadata,
|
||||
}, true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package flow
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeGatewaySkillPreservesPath(t *testing.T) {
|
||||
got, ok := normalizeGatewaySkill(SkillEntry{
|
||||
Name: "pdf",
|
||||
Description: "Read PDF instructions",
|
||||
Content: "Use a PDF-aware workflow.",
|
||||
Path: " /data/.agents/skills/pdf ",
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("normalizeGatewaySkill returned ok=false")
|
||||
}
|
||||
if got.Path != "/data/.agents/skills/pdf" {
|
||||
t.Fatalf("path = %q, want %q", got.Path, "/data/.agents/skills/pdf")
|
||||
}
|
||||
}
|
||||
@@ -175,6 +175,7 @@ func (h *ContainerdHandler) Register(e *echo.Echo) {
|
||||
group.GET("/skills", h.ListSkills)
|
||||
group.POST("/skills", h.UpsertSkills)
|
||||
group.DELETE("/skills", h.DeleteSkills)
|
||||
group.POST("/skills/actions", h.ApplySkillAction)
|
||||
// Terminal routes
|
||||
group.GET("/terminal", h.GetTerminalInfo)
|
||||
group.GET("/terminal/ws", h.HandleTerminalWS)
|
||||
|
||||
+95
-149
@@ -8,20 +8,22 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
skillset "github.com/memohai/memoh/internal/skills"
|
||||
)
|
||||
|
||||
const skillsDirPath = config.DefaultDataMount + "/skills"
|
||||
|
||||
type SkillItem struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Raw string `json:"raw"`
|
||||
SourcePath string `json:"source_path,omitempty"`
|
||||
SourceRoot string `json:"source_root,omitempty"`
|
||||
SourceKind string `json:"source_kind,omitempty"`
|
||||
Managed bool `json:"managed,omitempty"`
|
||||
State string `json:"state,omitempty"`
|
||||
ShadowedBy string `json:"shadowed_by,omitempty"`
|
||||
}
|
||||
|
||||
type SkillsResponse struct {
|
||||
@@ -36,12 +38,17 @@ type SkillsDeleteRequest struct {
|
||||
Names []string `json:"names"`
|
||||
}
|
||||
|
||||
type SkillsActionRequest struct {
|
||||
Action string `json:"action"`
|
||||
TargetPath string `json:"target_path"`
|
||||
}
|
||||
|
||||
type skillsOpResponse struct {
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
// ListSkills godoc
|
||||
// @Summary List skills from data directory
|
||||
// @Summary List skills from the bot container
|
||||
// @Tags containerd
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Success 200 {object} SkillsResponse
|
||||
@@ -54,18 +61,16 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
skills, err := h.loadSkillsFromContainer(c.Request().Context(), botID)
|
||||
|
||||
skills, err := h.listSkillsFromContainer(c.Request().Context(), botID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
for i := range skills {
|
||||
skills[i].Raw = skills[i].Content
|
||||
}
|
||||
return c.JSON(http.StatusOK, SkillsResponse{Skills: skills})
|
||||
}
|
||||
|
||||
// UpsertSkills godoc
|
||||
// @Summary Upload skills into data directory
|
||||
// @Summary Upload skills into Memoh-managed directory
|
||||
// @Tags containerd
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body SkillsUpsertRequest true "Skills payload"
|
||||
@@ -79,6 +84,7 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req SkillsUpsertRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
@@ -94,11 +100,11 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error {
|
||||
}
|
||||
|
||||
for _, raw := range req.Skills {
|
||||
parsed := parseSkillFile(raw, "")
|
||||
if !isValidSkillName(parsed.Name) {
|
||||
parsed := skillset.ParseFile(raw, "")
|
||||
dirPath, dirErr := skillset.ManagedSkillDirForName(parsed.Name)
|
||||
if dirErr != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "skill must have a valid name in YAML frontmatter")
|
||||
}
|
||||
dirPath := path.Join(skillsDirPath, parsed.Name)
|
||||
if err := client.Mkdir(ctx, dirPath); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("mkdir failed: %v", err))
|
||||
}
|
||||
@@ -112,7 +118,7 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error {
|
||||
}
|
||||
|
||||
// DeleteSkills godoc
|
||||
// @Summary Delete skills from data directory
|
||||
// @Summary Delete Memoh-managed skills
|
||||
// @Tags containerd
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body SkillsDeleteRequest true "Delete skills payload"
|
||||
@@ -126,6 +132,7 @@ func (h *ContainerdHandler) DeleteSkills(c echo.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req SkillsDeleteRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
@@ -142,160 +149,99 @@ func (h *ContainerdHandler) DeleteSkills(c echo.Context) error {
|
||||
|
||||
for _, name := range req.Names {
|
||||
skillName := strings.TrimSpace(name)
|
||||
if !isValidSkillName(skillName) {
|
||||
managedDir, dirErr := skillset.ManagedSkillDirForName(skillName)
|
||||
if dirErr != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid skill name")
|
||||
}
|
||||
_ = client.DeleteFile(ctx, path.Join(skillsDirPath, skillName), true)
|
||||
if _, statErr := client.Stat(ctx, managedDir); statErr != nil {
|
||||
return fsHTTPError(statErr)
|
||||
}
|
||||
if err := client.DeleteFile(ctx, managedDir, true); err != nil {
|
||||
return fsHTTPError(err)
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, skillsOpResponse{OK: true})
|
||||
}
|
||||
|
||||
// LoadSkills loads all skills from the container for the given bot.
|
||||
func (h *ContainerdHandler) LoadSkills(ctx context.Context, botID string) ([]SkillItem, error) {
|
||||
return h.loadSkillsFromContainer(ctx, botID)
|
||||
// ApplySkillAction godoc
|
||||
// @Summary Apply an action to a discovered or managed skill source
|
||||
// @Tags containerd
|
||||
// @Param bot_id path string true "Bot ID"
|
||||
// @Param payload body SkillsActionRequest true "Skill action payload"
|
||||
// @Success 200 {object} skillsOpResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /bots/{bot_id}/container/skills/actions [post].
|
||||
func (h *ContainerdHandler) ApplySkillAction(c echo.Context) error {
|
||||
botID, err := h.requireBotAccess(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req SkillsActionRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
|
||||
ctx := c.Request().Context()
|
||||
client, err := h.getGRPCClient(ctx, botID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("container not reachable: %v", err))
|
||||
}
|
||||
|
||||
if err := skillset.ApplyAction(ctx, client, skillset.ActionRequest{
|
||||
Action: req.Action,
|
||||
TargetPath: req.TargetPath,
|
||||
}); err != nil {
|
||||
return fsHTTPError(err)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, skillsOpResponse{OK: true})
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) loadSkillsFromContainer(ctx context.Context, botID string) ([]SkillItem, error) {
|
||||
// LoadSkills loads the effective skills from the container for the given bot.
|
||||
func (h *ContainerdHandler) LoadSkills(ctx context.Context, botID string) ([]SkillItem, error) {
|
||||
client, err := h.getGRPCClient(ctx, botID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries, err := client.ListDirAll(ctx, skillsDirPath, false)
|
||||
items, err := skillset.LoadEffective(ctx, client)
|
||||
if err != nil {
|
||||
return []SkillItem{}, nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var skills []SkillItem
|
||||
for _, entry := range entries {
|
||||
if !entry.GetIsDir() {
|
||||
if path.Base(entry.GetPath()) == "SKILL.md" {
|
||||
filePath := path.Join(skillsDirPath, "SKILL.md")
|
||||
raw, readErr := readContainerSkillFile(ctx, client, filePath)
|
||||
if readErr != nil {
|
||||
continue
|
||||
}
|
||||
parsed := parseSkillFile(raw, "default")
|
||||
skills = append(skills, skillItemFromParsed(parsed, raw))
|
||||
}
|
||||
continue
|
||||
}
|
||||
name := path.Base(entry.GetPath())
|
||||
if name == "" || name == "." {
|
||||
continue
|
||||
}
|
||||
filePath := path.Join(skillsDirPath, name, "SKILL.md")
|
||||
raw, readErr := readContainerSkillFile(ctx, client, filePath)
|
||||
if readErr != nil {
|
||||
continue
|
||||
}
|
||||
parsed := parseSkillFile(raw, name)
|
||||
skills = append(skills, skillItemFromParsed(parsed, raw))
|
||||
}
|
||||
return skills, nil
|
||||
return skillItemsFromEntries(items), nil
|
||||
}
|
||||
|
||||
func readContainerSkillFile(ctx context.Context, client *bridge.Client, filePath string) (string, error) {
|
||||
resp, err := client.ReadFile(ctx, filePath, 0, 0)
|
||||
func (h *ContainerdHandler) listSkillsFromContainer(ctx context.Context, botID string) ([]SkillItem, error) {
|
||||
client, err := h.getGRPCClient(ctx, botID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return resp.GetContent(), nil
|
||||
items, err := skillset.List(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return skillItemsFromEntries(items), nil
|
||||
}
|
||||
|
||||
func skillItemFromParsed(parsed parsedSkill, raw string) SkillItem {
|
||||
return SkillItem{
|
||||
Name: parsed.Name,
|
||||
Description: parsed.Description,
|
||||
Content: parsed.Content,
|
||||
Metadata: parsed.Metadata,
|
||||
Raw: raw,
|
||||
func skillItemsFromEntries(entries []skillset.Entry) []SkillItem {
|
||||
items := make([]SkillItem, len(entries))
|
||||
for i, entry := range entries {
|
||||
items[i] = SkillItem{
|
||||
Name: entry.Name,
|
||||
Description: entry.Description,
|
||||
Content: entry.Content,
|
||||
Metadata: entry.Metadata,
|
||||
Raw: entry.Raw,
|
||||
SourcePath: entry.SourcePath,
|
||||
SourceRoot: entry.SourceRoot,
|
||||
SourceKind: entry.SourceKind,
|
||||
Managed: entry.Managed,
|
||||
State: entry.State,
|
||||
ShadowedBy: entry.ShadowedBy,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- parsing logic (unchanged) ---
|
||||
|
||||
type parsedSkill struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
// parseSkillFile parses a SKILL.md file with YAML frontmatter delimited by "---".
|
||||
func parseSkillFile(raw string, fallbackName string) parsedSkill {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
result := parsedSkill{
|
||||
Name: strings.TrimSpace(fallbackName),
|
||||
Content: trimmed,
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "---") {
|
||||
return normalizeParsedSkill(result)
|
||||
}
|
||||
|
||||
rest := trimmed[3:]
|
||||
rest = strings.TrimLeft(rest, " \t")
|
||||
if len(rest) > 0 && rest[0] == '\n' {
|
||||
rest = rest[1:]
|
||||
} else if len(rest) > 1 && rest[0] == '\r' && rest[1] == '\n' {
|
||||
rest = rest[2:]
|
||||
}
|
||||
closingIdx := strings.Index(rest, "\n---")
|
||||
if closingIdx < 0 {
|
||||
return normalizeParsedSkill(result)
|
||||
}
|
||||
|
||||
frontmatterRaw := rest[:closingIdx]
|
||||
body := rest[closingIdx+4:]
|
||||
body = strings.TrimLeft(body, "\r\n")
|
||||
result.Content = body
|
||||
|
||||
var fm struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
}
|
||||
if err := yaml.Unmarshal([]byte(frontmatterRaw), &fm); err != nil {
|
||||
return normalizeParsedSkill(result)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(fm.Name) != "" {
|
||||
result.Name = strings.TrimSpace(fm.Name)
|
||||
}
|
||||
result.Description = strings.TrimSpace(fm.Description)
|
||||
result.Metadata = fm.Metadata
|
||||
|
||||
return normalizeParsedSkill(result)
|
||||
}
|
||||
|
||||
func normalizeParsedSkill(skill parsedSkill) parsedSkill {
|
||||
if strings.TrimSpace(skill.Name) == "" {
|
||||
skill.Name = "default"
|
||||
}
|
||||
skill.Name = strings.TrimSpace(skill.Name)
|
||||
skill.Description = strings.TrimSpace(skill.Description)
|
||||
skill.Content = strings.TrimSpace(skill.Content)
|
||||
|
||||
if skill.Description == "" {
|
||||
skill.Description = skill.Name
|
||||
}
|
||||
if skill.Content == "" {
|
||||
skill.Content = skill.Description
|
||||
}
|
||||
|
||||
return skill
|
||||
}
|
||||
|
||||
func isValidSkillName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(name, "..") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(name, "/") || strings.Contains(name, "\\") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return items
|
||||
}
|
||||
|
||||
@@ -1,42 +1,722 @@
|
||||
package handlers
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
func TestParseSkillFile_NoFrontmatterFallbacks(t *testing.T) {
|
||||
raw := "# Use this skill\n\nDo something useful."
|
||||
got := parseSkillFile(raw, "plain-skill")
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/labstack/echo/v4"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
if got.Name != "plain-skill" {
|
||||
t.Fatalf("expected name plain-skill, got %q", got.Name)
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/agent"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
skillset "github.com/memohai/memoh/internal/skills"
|
||||
"github.com/memohai/memoh/internal/workspace"
|
||||
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||
)
|
||||
|
||||
func TestListSkillsAPIReportsEffectiveShadowedAndSourceMetadata(t *testing.T) {
|
||||
env := newSkillsTestEnv(t)
|
||||
env.writeSkillFile(t, path.Join(skillset.ManagedDir(), "alpha", "SKILL.md"), managedSkillRaw("alpha", "Managed Alpha"))
|
||||
env.writeSkillFile(t, path.Join("/data/.agents/skills", "alpha", "SKILL.md"), managedSkillRaw("alpha", "Compat Alpha"))
|
||||
env.writeSkillFile(t, path.Join("/data/.agents/skills", "beta", "SKILL.md"), managedSkillRaw("beta", "Compat Beta"))
|
||||
|
||||
rec, err := env.callJSON(t, http.MethodGet, "/bots/:bot_id/container/skills", nil, env.handler.ListSkills)
|
||||
if err != nil {
|
||||
t.Fatalf("ListSkills returned error: %v", err)
|
||||
}
|
||||
if got.Description != "plain-skill" {
|
||||
t.Fatalf("expected description plain-skill, got %q", got.Description)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("ListSkills status = %d, want 200", rec.Code)
|
||||
}
|
||||
if got.Content != raw {
|
||||
t.Fatalf("expected content to keep original markdown, got %q", got.Content)
|
||||
|
||||
var resp SkillsResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode skills response: %v", err)
|
||||
}
|
||||
if len(resp.Skills) != 3 {
|
||||
t.Fatalf("expected 3 skills, got %d", len(resp.Skills))
|
||||
}
|
||||
|
||||
alphaManaged := mustFindSkillByPath(t, resp.Skills, path.Join(skillset.ManagedDir(), "alpha", "SKILL.md"))
|
||||
if !alphaManaged.Managed {
|
||||
t.Fatalf("managed alpha should be managed: %+v", alphaManaged)
|
||||
}
|
||||
if alphaManaged.State != skillset.StateEffective {
|
||||
t.Fatalf("managed alpha state = %q, want %q", alphaManaged.State, skillset.StateEffective)
|
||||
}
|
||||
if alphaManaged.SourceKind != skillset.SourceKindManaged {
|
||||
t.Fatalf("managed alpha source_kind = %q, want %q", alphaManaged.SourceKind, skillset.SourceKindManaged)
|
||||
}
|
||||
|
||||
alphaCompatPath := path.Join("/data/.agents/skills", "alpha", "SKILL.md")
|
||||
alphaCompat := mustFindSkillByPath(t, resp.Skills, alphaCompatPath)
|
||||
if alphaCompat.Managed {
|
||||
t.Fatalf("compat alpha should not be managed: %+v", alphaCompat)
|
||||
}
|
||||
if alphaCompat.State != skillset.StateShadowed {
|
||||
t.Fatalf("compat alpha state = %q, want %q", alphaCompat.State, skillset.StateShadowed)
|
||||
}
|
||||
if alphaCompat.ShadowedBy != alphaManaged.SourcePath {
|
||||
t.Fatalf("compat alpha shadowed_by = %q, want %q", alphaCompat.ShadowedBy, alphaManaged.SourcePath)
|
||||
}
|
||||
if alphaCompat.SourceKind != skillset.SourceKindCompat {
|
||||
t.Fatalf("compat alpha source_kind = %q, want %q", alphaCompat.SourceKind, skillset.SourceKindCompat)
|
||||
}
|
||||
|
||||
betaCompat := mustFindSkillByPath(t, resp.Skills, path.Join("/data/.agents/skills", "beta", "SKILL.md"))
|
||||
if betaCompat.State != skillset.StateEffective {
|
||||
t.Fatalf("beta compat state = %q, want %q", betaCompat.State, skillset.StateEffective)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(env.localPath(skillset.IndexFilePath)); err != nil {
|
||||
t.Fatalf("expected derived skill index to be written: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSkillFile_FrontmatterDescriptionFallback(t *testing.T) {
|
||||
raw := "---\nname: hello-skill\n---\n\nBody content"
|
||||
got := parseSkillFile(raw, "fallback")
|
||||
func TestSkillsActionsAPIAdoptDisableEnableAndDeleteManaged(t *testing.T) {
|
||||
env := newSkillsTestEnv(t)
|
||||
externalPath := path.Join("/data/.agents/skills", "alpha", "SKILL.md")
|
||||
managedPath := path.Join(skillset.ManagedDir(), "alpha", "SKILL.md")
|
||||
env.writeSkillFile(t, externalPath, managedSkillRaw("alpha", "Compat Alpha"))
|
||||
|
||||
if got.Name != "hello-skill" {
|
||||
t.Fatalf("expected frontmatter name hello-skill, got %q", got.Name)
|
||||
rec, err := env.callJSON(t, http.MethodPost, "/bots/:bot_id/container/skills/actions", SkillsActionRequest{
|
||||
Action: skillset.ActionAdopt,
|
||||
TargetPath: externalPath,
|
||||
}, env.handler.ApplySkillAction)
|
||||
if err != nil {
|
||||
t.Fatalf("adopt returned error: %v", err)
|
||||
}
|
||||
if got.Description != "hello-skill" {
|
||||
t.Fatalf("expected description fallback to name, got %q", got.Description)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("adopt status = %d, want 200", rec.Code)
|
||||
}
|
||||
if got.Content != "Body content" {
|
||||
t.Fatalf("expected content Body content, got %q", got.Content)
|
||||
if _, err := os.Stat(env.localPath(managedPath)); err != nil {
|
||||
t.Fatalf("expected managed skill after adopt: %v", err)
|
||||
}
|
||||
|
||||
adopted := env.listSkills(t)
|
||||
adoptedManaged := mustFindSkillByPath(t, adopted, managedPath)
|
||||
if adoptedManaged.State != skillset.StateEffective {
|
||||
t.Fatalf("managed adopted skill state = %q, want %q", adoptedManaged.State, skillset.StateEffective)
|
||||
}
|
||||
adoptedCompat := mustFindSkillByPath(t, adopted, externalPath)
|
||||
if adoptedCompat.State != skillset.StateShadowed {
|
||||
t.Fatalf("compat adopted skill state = %q, want %q", adoptedCompat.State, skillset.StateShadowed)
|
||||
}
|
||||
|
||||
rec, err = env.callJSON(t, http.MethodPost, "/bots/:bot_id/container/skills/actions", SkillsActionRequest{
|
||||
Action: skillset.ActionDisable,
|
||||
TargetPath: managedPath,
|
||||
}, env.handler.ApplySkillAction)
|
||||
if err != nil {
|
||||
t.Fatalf("disable returned error: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("disable status = %d, want 200", rec.Code)
|
||||
}
|
||||
|
||||
disabled := env.listSkills(t)
|
||||
disabledManaged := mustFindSkillByPath(t, disabled, managedPath)
|
||||
if disabledManaged.State != skillset.StateDisabled {
|
||||
t.Fatalf("managed disabled skill state = %q, want %q", disabledManaged.State, skillset.StateDisabled)
|
||||
}
|
||||
disabledCompat := mustFindSkillByPath(t, disabled, externalPath)
|
||||
if disabledCompat.State != skillset.StateEffective {
|
||||
t.Fatalf("compat fallback state = %q, want %q", disabledCompat.State, skillset.StateEffective)
|
||||
}
|
||||
|
||||
rec, err = env.callJSON(t, http.MethodPost, "/bots/:bot_id/container/skills/actions", SkillsActionRequest{
|
||||
Action: skillset.ActionEnable,
|
||||
TargetPath: managedPath,
|
||||
}, env.handler.ApplySkillAction)
|
||||
if err != nil {
|
||||
t.Fatalf("enable returned error: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("enable status = %d, want 200", rec.Code)
|
||||
}
|
||||
|
||||
reenabled := env.listSkills(t)
|
||||
if got := mustFindSkillByPath(t, reenabled, managedPath).State; got != skillset.StateEffective {
|
||||
t.Fatalf("managed state after enable = %q, want %q", got, skillset.StateEffective)
|
||||
}
|
||||
if got := mustFindSkillByPath(t, reenabled, externalPath).State; got != skillset.StateShadowed {
|
||||
t.Fatalf("compat state after enable = %q, want %q", got, skillset.StateShadowed)
|
||||
}
|
||||
|
||||
rec, err = env.callJSON(t, http.MethodDelete, "/bots/:bot_id/container/skills", SkillsDeleteRequest{
|
||||
Names: []string{"alpha"},
|
||||
}, env.handler.DeleteSkills)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteSkills returned error: %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("DeleteSkills status = %d, want 200", rec.Code)
|
||||
}
|
||||
if _, err := os.Stat(env.localPath(managedPath)); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected managed skill to be removed, stat err=%v", err)
|
||||
}
|
||||
|
||||
deleted := env.listSkills(t)
|
||||
if len(deleted) != 1 {
|
||||
t.Fatalf("expected only compat skill after delete, got %d items", len(deleted))
|
||||
}
|
||||
if got := mustFindSkillByPath(t, deleted, externalPath).State; got != skillset.StateEffective {
|
||||
t.Fatalf("compat state after delete = %q, want %q", got, skillset.StateEffective)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSkillFile_EmptyBodyFallbacksToDescription(t *testing.T) {
|
||||
raw := "---\nname: hello-skill\ndescription: say hello\n---\n"
|
||||
got := parseSkillFile(raw, "fallback")
|
||||
func TestDeleteSkillsAPIRejectsExternalOnlySkill(t *testing.T) {
|
||||
env := newSkillsTestEnv(t)
|
||||
env.writeSkillFile(t, path.Join("/data/.agents/skills", "alpha", "SKILL.md"), managedSkillRaw("alpha", "Compat Alpha"))
|
||||
|
||||
if got.Content != "say hello" {
|
||||
t.Fatalf("expected content fallback to description, got %q", got.Content)
|
||||
_, err := env.callJSON(t, http.MethodDelete, "/bots/:bot_id/container/skills", SkillsDeleteRequest{
|
||||
Names: []string{"alpha"},
|
||||
}, env.handler.DeleteSkills)
|
||||
if err == nil {
|
||||
t.Fatal("expected deleting external-only skill to fail")
|
||||
}
|
||||
var httpErr *echo.HTTPError
|
||||
if !errors.As(err, &httpErr) {
|
||||
t.Fatalf("expected echo.HTTPError, got %T", err)
|
||||
}
|
||||
if httpErr.Code != http.StatusNotFound {
|
||||
t.Fatalf("delete external-only status = %d, want 404", httpErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertSkillsAPIRejectsTraversalName(t *testing.T) {
|
||||
env := newSkillsTestEnv(t)
|
||||
|
||||
_, err := env.callJSON(t, http.MethodPost, "/bots/:bot_id/container/skills", SkillsUpsertRequest{
|
||||
Skills: []string{"---\nname: ..\ndescription: Escape\n---\n\n# Escape"},
|
||||
}, env.handler.UpsertSkills)
|
||||
if err == nil {
|
||||
t.Fatal("expected upserting traversal skill name to fail")
|
||||
}
|
||||
var httpErr *echo.HTTPError
|
||||
if !errors.As(err, &httpErr) {
|
||||
t.Fatalf("expected echo.HTTPError, got %T", err)
|
||||
}
|
||||
if httpErr.Code != http.StatusBadRequest {
|
||||
t.Fatalf("upsert traversal status = %d, want 400", httpErr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSkillsUsesEffectiveSetAndPromptReflectsOverrideFallback(t *testing.T) {
|
||||
env := newSkillsTestEnv(t)
|
||||
managedPath := path.Join(skillset.ManagedDir(), "alpha", "SKILL.md")
|
||||
compatPath := path.Join("/data/.agents/skills", "alpha", "SKILL.md")
|
||||
env.writeSkillFile(t, managedPath, managedSkillRaw("alpha", "Managed Alpha"))
|
||||
env.writeSkillFile(t, compatPath, managedSkillRaw("alpha", "Compat Alpha"))
|
||||
env.writeSkillFile(t, path.Join("/data/.agents/skills", "beta", "SKILL.md"), managedSkillRaw("beta", "Compat Beta"))
|
||||
|
||||
loaded, err := env.handler.LoadSkills(context.Background(), env.botID)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSkills returned error: %v", err)
|
||||
}
|
||||
if len(loaded) != 2 {
|
||||
t.Fatalf("expected 2 effective skills, got %d", len(loaded))
|
||||
}
|
||||
if got := loaded[0].Name + ":" + loaded[0].Description + "|" + loaded[1].Name + ":" + loaded[1].Description; !strings.Contains(got, "alpha:Managed Alpha") {
|
||||
t.Fatalf("effective skills should include managed alpha, got %s", got)
|
||||
}
|
||||
promptBefore := promptFromLoadedSkills(loaded)
|
||||
if !strings.Contains(promptBefore, "Managed Alpha") {
|
||||
t.Fatalf("prompt should include managed alpha description:\n%s", promptBefore)
|
||||
}
|
||||
if strings.Contains(promptBefore, "Compat Alpha") {
|
||||
t.Fatalf("prompt should not include shadowed compat alpha:\n%s", promptBefore)
|
||||
}
|
||||
|
||||
client, err := env.handler.manager.MCPClient(context.Background(), env.botID)
|
||||
if err != nil {
|
||||
t.Fatalf("get bridge client: %v", err)
|
||||
}
|
||||
if err := skillset.ApplyAction(context.Background(), client, skillset.ActionRequest{
|
||||
Action: skillset.ActionDisable,
|
||||
TargetPath: managedPath,
|
||||
}); err != nil {
|
||||
t.Fatalf("disable managed alpha via skillset.ApplyAction: %v", err)
|
||||
}
|
||||
|
||||
fallback, err := env.handler.LoadSkills(context.Background(), env.botID)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSkills after disable returned error: %v", err)
|
||||
}
|
||||
if len(fallback) != 2 {
|
||||
t.Fatalf("expected 2 effective skills after disable, got %d", len(fallback))
|
||||
}
|
||||
alphaFallback := mustFindLoadedSkillByName(t, fallback, "alpha")
|
||||
if alphaFallback.Description != "Compat Alpha" {
|
||||
t.Fatalf("effective alpha description after disable = %q, want %q", alphaFallback.Description, "Compat Alpha")
|
||||
}
|
||||
promptAfter := promptFromLoadedSkills(fallback)
|
||||
if !strings.Contains(promptAfter, "Compat Alpha") {
|
||||
t.Fatalf("prompt should include compat alpha after fallback:\n%s", promptAfter)
|
||||
}
|
||||
if strings.Contains(promptAfter, "Managed Alpha") {
|
||||
t.Fatalf("prompt should not include disabled managed alpha after fallback:\n%s", promptAfter)
|
||||
}
|
||||
}
|
||||
|
||||
type skillsTestEnv struct {
|
||||
handler *ContainerdHandler
|
||||
dataRoot string
|
||||
botID string
|
||||
userID string
|
||||
}
|
||||
|
||||
func newSkillsTestEnv(t *testing.T) *skillsTestEnv {
|
||||
t.Helper()
|
||||
|
||||
dataRoot, err := newSkillsTestDataRoot()
|
||||
if err != nil {
|
||||
t.Fatalf("create temp data root: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.RemoveAll(dataRoot) })
|
||||
userID := "00000000-0000-0000-0000-000000000001"
|
||||
botID := "00000000-0000-0000-0000-000000000010"
|
||||
startSkillsTestBridgeServer(t, dataRoot, botID)
|
||||
|
||||
cfg := config.WorkspaceConfig{DataRoot: dataRoot}
|
||||
db := &skillsTestDB{userID: userID, botID: botID}
|
||||
manager := workspace.NewManager(slog.Default(), nil, cfg, "", nil)
|
||||
handler := NewContainerdHandler(
|
||||
slog.Default(),
|
||||
manager,
|
||||
cfg,
|
||||
"",
|
||||
bots.NewService(slog.Default(), sqlc.New(db)),
|
||||
accounts.NewService(slog.Default(), sqlc.New(db)),
|
||||
nil,
|
||||
)
|
||||
|
||||
return &skillsTestEnv{
|
||||
handler: handler,
|
||||
dataRoot: dataRoot,
|
||||
botID: botID,
|
||||
userID: userID,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *skillsTestEnv) callJSON(t *testing.T, method, routePath string, body any, fn func(echo.Context) error) (*httptest.ResponseRecorder, error) {
|
||||
t.Helper()
|
||||
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal request body: %v", err)
|
||||
}
|
||||
bodyReader = strings.NewReader(string(data))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(method, routePath, bodyReader)
|
||||
if body != nil {
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
ctx := echo.New().NewContext(req, rec)
|
||||
ctx.SetPath(routePath)
|
||||
ctx.SetParamNames("bot_id")
|
||||
ctx.SetParamValues(e.botID)
|
||||
ctx.Set("user", &jwt.Token{
|
||||
Valid: true,
|
||||
Claims: jwt.MapClaims{"user_id": e.userID, "sub": e.userID},
|
||||
})
|
||||
|
||||
return rec, fn(ctx)
|
||||
}
|
||||
|
||||
func (e *skillsTestEnv) listSkills(t *testing.T) []SkillItem {
|
||||
t.Helper()
|
||||
rec, err := e.callJSON(t, http.MethodGet, "/bots/:bot_id/container/skills", nil, e.handler.ListSkills)
|
||||
if err != nil {
|
||||
t.Fatalf("ListSkills returned error: %v", err)
|
||||
}
|
||||
var resp SkillsResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode ListSkills response: %v", err)
|
||||
}
|
||||
return resp.Skills
|
||||
}
|
||||
|
||||
func (e *skillsTestEnv) writeSkillFile(t *testing.T, containerPath, raw string) {
|
||||
t.Helper()
|
||||
local := e.localPath(containerPath)
|
||||
if err := os.MkdirAll(filepath.Dir(local), 0o750); err != nil {
|
||||
t.Fatalf("mkdir %s: %v", filepath.Dir(local), err)
|
||||
}
|
||||
//nolint:gosec // test-only temp workspace path
|
||||
if err := os.WriteFile(local, []byte(raw), 0o600); err != nil {
|
||||
t.Fatalf("write %s: %v", local, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *skillsTestEnv) localPath(containerPath string) string {
|
||||
clean := path.Clean("/" + strings.TrimSpace(containerPath))
|
||||
if clean == "/" {
|
||||
return e.dataRoot
|
||||
}
|
||||
return filepath.Join(e.dataRoot, filepath.FromSlash(strings.TrimPrefix(clean, "/")))
|
||||
}
|
||||
|
||||
func newSkillsTestDataRoot() (string, error) {
|
||||
var lastErr error
|
||||
for _, dir := range []string{"/tmp", ""} {
|
||||
dataRoot, err := os.MkdirTemp(dir, "mh-sk-")
|
||||
if err == nil {
|
||||
return dataRoot, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
type skillsTestDB struct {
|
||||
userID string
|
||||
botID string
|
||||
}
|
||||
|
||||
func (*skillsTestDB) Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) {
|
||||
return pgconn.CommandTag{}, nil
|
||||
}
|
||||
|
||||
func (*skillsTestDB) Query(context.Context, string, ...interface{}) (pgx.Rows, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *skillsTestDB) QueryRow(_ context.Context, sql string, _ ...interface{}) pgx.Row {
|
||||
switch {
|
||||
case strings.Contains(sql, "FROM users WHERE id = $1"):
|
||||
return makeUserRow(mustParseUUID(d.userID), "user")
|
||||
case strings.Contains(sql, "FROM bots"):
|
||||
return makeBotRow(mustParseUUID(d.botID), mustParseUUID(d.userID))
|
||||
default:
|
||||
return &skillsTestRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }}
|
||||
}
|
||||
}
|
||||
|
||||
type skillsTestRow struct {
|
||||
scanFunc func(dest ...any) error
|
||||
}
|
||||
|
||||
func (r *skillsTestRow) Scan(dest ...any) error {
|
||||
return r.scanFunc(dest...)
|
||||
}
|
||||
|
||||
func makeUserRow(userID pgtype.UUID, role string) *skillsTestRow {
|
||||
return &skillsTestRow{
|
||||
scanFunc: func(dest ...any) error {
|
||||
if len(dest) < 14 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
*dest[0].(*pgtype.UUID) = userID
|
||||
*dest[1].(*pgtype.Text) = pgtype.Text{String: "owner", Valid: true}
|
||||
*dest[2].(*pgtype.Text) = pgtype.Text{}
|
||||
*dest[3].(*pgtype.Text) = pgtype.Text{}
|
||||
*dest[4].(*string) = role
|
||||
*dest[5].(*pgtype.Text) = pgtype.Text{String: "Owner", Valid: true}
|
||||
*dest[6].(*pgtype.Text) = pgtype.Text{}
|
||||
*dest[7].(*string) = "UTC"
|
||||
*dest[8].(*pgtype.Text) = pgtype.Text{}
|
||||
*dest[9].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
||||
*dest[10].(*bool) = true
|
||||
*dest[11].(*[]byte) = []byte(`{}`)
|
||||
*dest[12].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
||||
*dest[13].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeBotRow(botID, ownerUserID pgtype.UUID) *skillsTestRow {
|
||||
return &skillsTestRow{
|
||||
scanFunc: func(dest ...any) error {
|
||||
if len(dest) < 23 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
*dest[0].(*pgtype.UUID) = botID
|
||||
*dest[1].(*pgtype.UUID) = ownerUserID
|
||||
*dest[2].(*pgtype.Text) = pgtype.Text{String: "test-bot", Valid: true}
|
||||
*dest[3].(*pgtype.Text) = pgtype.Text{}
|
||||
*dest[4].(*pgtype.Text) = pgtype.Text{}
|
||||
*dest[5].(*bool) = true
|
||||
*dest[6].(*string) = bots.BotStatusReady
|
||||
*dest[7].(*string) = "en"
|
||||
*dest[8].(*bool) = false
|
||||
*dest[9].(*string) = "medium"
|
||||
*dest[10].(*pgtype.UUID) = pgtype.UUID{}
|
||||
*dest[11].(*pgtype.UUID) = pgtype.UUID{}
|
||||
*dest[12].(*pgtype.UUID) = pgtype.UUID{}
|
||||
*dest[13].(*bool) = false
|
||||
*dest[14].(*int32) = 30
|
||||
*dest[15].(*string) = ""
|
||||
*dest[16].(*bool) = false
|
||||
*dest[17].(*int32) = 100000
|
||||
*dest[18].(*int32) = 80
|
||||
*dest[19].(*pgtype.UUID) = pgtype.UUID{}
|
||||
*dest[20].(*[]byte) = []byte(`{}`)
|
||||
*dest[21].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
||||
*dest[22].(*pgtype.Timestamptz) = pgtype.Timestamptz{}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseUUID(s string) pgtype.UUID {
|
||||
var u pgtype.UUID
|
||||
if err := u.Scan(s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
type skillsTestBridgeServer struct {
|
||||
pb.UnimplementedContainerServiceServer
|
||||
root string
|
||||
}
|
||||
|
||||
func startSkillsTestBridgeServer(t *testing.T, dataRoot, botID string) {
|
||||
t.Helper()
|
||||
|
||||
socketPath := filepath.Join(dataRoot, "run", botID, "bridge.sock")
|
||||
if err := os.MkdirAll(filepath.Dir(socketPath), 0o750); err != nil {
|
||||
t.Fatalf("mkdir socket dir: %v", err)
|
||||
}
|
||||
var lc net.ListenConfig
|
||||
lis, err := lc.Listen(context.Background(), "unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("listen unix socket: %v", err)
|
||||
}
|
||||
|
||||
srv := grpc.NewServer()
|
||||
pb.RegisterContainerServiceServer(srv, &skillsTestBridgeServer{root: dataRoot})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_ = srv.Serve(lis)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
srv.Stop()
|
||||
_ = lis.Close()
|
||||
<-done
|
||||
})
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) ListDir(_ context.Context, req *pb.ListDirRequest) (*pb.ListDirResponse, error) {
|
||||
containerPath, localPath := s.resolvePath(req.GetPath())
|
||||
entries, err := os.ReadDir(localPath)
|
||||
if err != nil {
|
||||
return nil, toStatusError(err, req.GetPath())
|
||||
}
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() })
|
||||
resp := make([]*pb.FileEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "stat %s: %v", entry.Name(), err)
|
||||
}
|
||||
entryPath := path.Join(containerPath, entry.Name())
|
||||
if containerPath == "/" {
|
||||
entryPath = "/" + entry.Name()
|
||||
}
|
||||
resp = append(resp, &pb.FileEntry{
|
||||
Path: entryPath,
|
||||
IsDir: entry.IsDir(),
|
||||
Size: info.Size(),
|
||||
Mode: info.Mode().String(),
|
||||
ModTime: info.ModTime().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
if len(resp) > 1<<31-1 {
|
||||
return nil, status.Error(codes.Internal, "too many entries")
|
||||
}
|
||||
//nolint:gosec // len(resp) is bounds-checked just above
|
||||
totalCount := int32(len(resp))
|
||||
return &pb.ListDirResponse{
|
||||
Entries: resp,
|
||||
TotalCount: totalCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error {
|
||||
_, localPath := s.resolvePath(req.GetPath())
|
||||
//nolint:gosec // test-only temp workspace path
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
return toStatusError(err, req.GetPath())
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return stream.Send(&pb.DataChunk{Data: data})
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) WriteRaw(stream pb.ContainerService_WriteRawServer) error {
|
||||
var containerPath string
|
||||
var data []byte
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if containerPath == "" {
|
||||
containerPath = chunk.GetPath()
|
||||
}
|
||||
data = append(data, chunk.GetData()...)
|
||||
}
|
||||
if strings.TrimSpace(containerPath) == "" {
|
||||
return status.Error(codes.InvalidArgument, "path is required")
|
||||
}
|
||||
_, localPath := s.resolvePath(containerPath)
|
||||
if err := os.MkdirAll(filepath.Dir(localPath), 0o750); err != nil {
|
||||
return status.Errorf(codes.Internal, "mkdir parent for %s: %v", containerPath, err)
|
||||
}
|
||||
//nolint:gosec // test-only temp workspace path
|
||||
if err := os.WriteFile(localPath, data, 0o600); err != nil {
|
||||
return status.Errorf(codes.Internal, "write %s: %v", containerPath, err)
|
||||
}
|
||||
return stream.SendAndClose(&pb.WriteRawResponse{BytesWritten: int64(len(data))})
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) WriteFile(_ context.Context, req *pb.WriteFileRequest) (*pb.WriteFileResponse, error) {
|
||||
_, localPath := s.resolvePath(req.GetPath())
|
||||
if err := os.MkdirAll(filepath.Dir(localPath), 0o750); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "mkdir parent for %s: %v", req.GetPath(), err)
|
||||
}
|
||||
//nolint:gosec // test-only temp workspace path
|
||||
if err := os.WriteFile(localPath, req.GetContent(), 0o600); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "write %s: %v", req.GetPath(), err)
|
||||
}
|
||||
return &pb.WriteFileResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) Stat(_ context.Context, req *pb.StatRequest) (*pb.StatResponse, error) {
|
||||
containerPath, localPath := s.resolvePath(req.GetPath())
|
||||
info, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
return nil, toStatusError(err, req.GetPath())
|
||||
}
|
||||
return &pb.StatResponse{Entry: &pb.FileEntry{
|
||||
Path: containerPath,
|
||||
IsDir: info.IsDir(),
|
||||
Size: info.Size(),
|
||||
Mode: info.Mode().String(),
|
||||
ModTime: info.ModTime().UTC().Format(time.RFC3339),
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) Mkdir(_ context.Context, req *pb.MkdirRequest) (*pb.MkdirResponse, error) {
|
||||
_, localPath := s.resolvePath(req.GetPath())
|
||||
if err := os.MkdirAll(localPath, 0o750); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "mkdir %s: %v", req.GetPath(), err)
|
||||
}
|
||||
return &pb.MkdirResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) DeleteFile(_ context.Context, req *pb.DeleteFileRequest) (*pb.DeleteFileResponse, error) {
|
||||
_, localPath := s.resolvePath(req.GetPath())
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
return nil, toStatusError(err, req.GetPath())
|
||||
}
|
||||
var err error
|
||||
if req.GetRecursive() {
|
||||
err = os.RemoveAll(localPath)
|
||||
} else {
|
||||
err = os.Remove(localPath)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "delete %s: %v", req.GetPath(), err)
|
||||
}
|
||||
return &pb.DeleteFileResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *skillsTestBridgeServer) resolvePath(containerPath string) (string, string) {
|
||||
clean := path.Clean("/" + strings.TrimSpace(containerPath))
|
||||
if clean == "/" {
|
||||
return clean, s.root
|
||||
}
|
||||
return clean, filepath.Join(s.root, filepath.FromSlash(strings.TrimPrefix(clean, "/")))
|
||||
}
|
||||
|
||||
func toStatusError(err error, containerPath string) error {
|
||||
if os.IsNotExist(err) {
|
||||
return status.Errorf(codes.NotFound, "path not found: %s", containerPath)
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return status.Errorf(codes.PermissionDenied, "permission denied: %s", containerPath)
|
||||
}
|
||||
return status.Errorf(codes.Internal, "%v", err)
|
||||
}
|
||||
|
||||
func mustFindSkillByPath(t *testing.T, items []SkillItem, sourcePath string) SkillItem {
|
||||
t.Helper()
|
||||
for _, item := range items {
|
||||
if item.SourcePath == sourcePath {
|
||||
return item
|
||||
}
|
||||
}
|
||||
t.Fatalf("skill with source path %q not found in %+v", sourcePath, items)
|
||||
return SkillItem{}
|
||||
}
|
||||
|
||||
func mustFindLoadedSkillByName(t *testing.T, items []SkillItem, name string) SkillItem {
|
||||
t.Helper()
|
||||
for _, item := range items {
|
||||
if item.Name == name {
|
||||
return item
|
||||
}
|
||||
}
|
||||
t.Fatalf("loaded skill %q not found in %+v", name, items)
|
||||
return SkillItem{}
|
||||
}
|
||||
|
||||
func promptFromLoadedSkills(items []SkillItem) string {
|
||||
skills := make([]agent.SkillEntry, 0, len(items))
|
||||
for _, item := range items {
|
||||
skills = append(skills, agent.SkillEntry{
|
||||
Name: item.Name,
|
||||
Description: item.Description,
|
||||
Content: item.Content,
|
||||
Metadata: item.Metadata,
|
||||
})
|
||||
}
|
||||
return agent.GenerateSystemPrompt(agent.SystemPromptParams{
|
||||
SessionType: "chat",
|
||||
Skills: skills,
|
||||
Now: time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC),
|
||||
Timezone: "UTC",
|
||||
})
|
||||
}
|
||||
|
||||
func managedSkillRaw(name, description string) string {
|
||||
return "---\nname: " + name + "\ndescription: " + description + "\n---\n\n# " + description + "\n"
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
skillset "github.com/memohai/memoh/internal/skills"
|
||||
"github.com/memohai/memoh/internal/workspace"
|
||||
)
|
||||
|
||||
@@ -276,7 +277,7 @@ func (h *SupermarketHandler) InstallSkill(c echo.Context) error {
|
||||
}
|
||||
defer func() { _ = gz.Close() }()
|
||||
|
||||
skillDir := path.Join(skillsDirPath, skillID)
|
||||
skillDir := path.Join(skillset.ManagedDir(), skillID)
|
||||
if err := client.Mkdir(ctx, skillDir); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("mkdir failed: %v", err))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,509 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||
)
|
||||
|
||||
const (
|
||||
ManagedDirPath = config.DefaultDataMount + "/skills"
|
||||
LegacyDirPath = config.DefaultDataMount + "/.skills"
|
||||
IndexDirPath = config.DefaultDataMount + "/.memoh/skills"
|
||||
IndexFilePath = IndexDirPath + "/index.json"
|
||||
|
||||
SourceKindManaged = "managed"
|
||||
SourceKindLegacy = "legacy"
|
||||
SourceKindCompat = "compat"
|
||||
|
||||
StateEffective = "effective"
|
||||
StateShadowed = "shadowed"
|
||||
StateDisabled = "disabled"
|
||||
|
||||
ActionAdopt = "adopt"
|
||||
ActionDisable = "disable"
|
||||
ActionEnable = "enable"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Raw string `json:"raw"`
|
||||
SourcePath string `json:"source_path,omitempty"`
|
||||
SourceRoot string `json:"source_root,omitempty"`
|
||||
SourceKind string `json:"source_kind,omitempty"`
|
||||
Managed bool `json:"managed,omitempty"`
|
||||
State string `json:"state,omitempty"`
|
||||
ShadowedBy string `json:"shadowed_by,omitempty"`
|
||||
}
|
||||
|
||||
type ActionRequest struct {
|
||||
Action string
|
||||
TargetPath string
|
||||
}
|
||||
|
||||
type Parsed struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
type Root struct {
|
||||
Path string
|
||||
Kind string
|
||||
Managed bool
|
||||
}
|
||||
|
||||
type fileClient interface {
|
||||
ListDirAll(ctx context.Context, path string, recursive bool) ([]*pb.FileEntry, error)
|
||||
ReadRaw(ctx context.Context, path string) (io.ReadCloser, error)
|
||||
WriteRaw(ctx context.Context, path string, r io.Reader) (int64, error)
|
||||
Mkdir(ctx context.Context, path string) error
|
||||
}
|
||||
|
||||
type indexState struct {
|
||||
Version int `json:"version"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
Overrides map[string]indexOverride `json:"overrides,omitempty"`
|
||||
Items []indexedItem `json:"items,omitempty"`
|
||||
}
|
||||
|
||||
type indexOverride struct {
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
type indexedItem struct {
|
||||
Name string `json:"name"`
|
||||
SourcePath string `json:"source_path"`
|
||||
SourceKind string `json:"source_kind"`
|
||||
Managed bool `json:"managed"`
|
||||
State string `json:"state"`
|
||||
ShadowedBy string `json:"shadowed_by,omitempty"`
|
||||
ContentHash string `json:"content_hash,omitempty"`
|
||||
LastSeenAt string `json:"last_seen_at,omitempty"`
|
||||
}
|
||||
|
||||
func ManagedDir() string {
|
||||
return ManagedDirPath
|
||||
}
|
||||
|
||||
func ManagedSkillDirForName(name string) (string, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if !IsValidName(name) {
|
||||
return "", bridge.ErrBadRequest
|
||||
}
|
||||
|
||||
dirPath := path.Clean(path.Join(ManagedDirPath, name))
|
||||
if dirPath == ManagedDirPath || !strings.HasPrefix(dirPath, ManagedDirPath+"/") {
|
||||
return "", bridge.ErrBadRequest
|
||||
}
|
||||
return dirPath, nil
|
||||
}
|
||||
|
||||
func ContainerEnv() []string {
|
||||
return []string{
|
||||
"HOME=" + config.DefaultDataMount,
|
||||
"XDG_CONFIG_HOME=" + path.Join(config.DefaultDataMount, ".config"),
|
||||
"XDG_DATA_HOME=" + path.Join(config.DefaultDataMount, ".local", "share"),
|
||||
"XDG_CACHE_HOME=" + path.Join(config.DefaultDataMount, ".cache"),
|
||||
}
|
||||
}
|
||||
|
||||
func DiscoveryRoots() []Root {
|
||||
return []Root{
|
||||
{Path: ManagedDirPath, Kind: SourceKindManaged, Managed: true},
|
||||
{Path: LegacyDirPath, Kind: SourceKindLegacy, Managed: false},
|
||||
{Path: path.Join(config.DefaultDataMount, ".agents", "skills"), Kind: SourceKindCompat, Managed: false},
|
||||
{Path: path.Join("/root", ".agents", "skills"), Kind: SourceKindCompat, Managed: false},
|
||||
}
|
||||
}
|
||||
|
||||
func List(ctx context.Context, client fileClient) ([]Entry, error) {
|
||||
idx := readIndex(ctx, client)
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
resolved := resolve(items, idx.Overrides)
|
||||
writeIndex(ctx, client, idx.withItems(resolved))
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func LoadEffective(ctx context.Context, client fileClient) ([]Entry, error) {
|
||||
items, err := List(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]Entry, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.State == StateEffective {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func ApplyAction(ctx context.Context, client fileClient, req ActionRequest) error {
|
||||
targetPath := strings.TrimSpace(req.TargetPath)
|
||||
if targetPath == "" {
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
|
||||
switch strings.TrimSpace(req.Action) {
|
||||
case ActionDisable:
|
||||
idx := readIndex(ctx, client)
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
if !containsSourcePath(items, targetPath) {
|
||||
return bridge.ErrNotFound
|
||||
}
|
||||
if idx.Overrides == nil {
|
||||
idx.Overrides = make(map[string]indexOverride)
|
||||
}
|
||||
idx.Overrides[targetPath] = indexOverride{Disabled: true}
|
||||
writeIndex(ctx, client, idx.withItems(resolve(items, idx.Overrides)))
|
||||
return nil
|
||||
case ActionEnable:
|
||||
idx := readIndex(ctx, client)
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
if !containsSourcePath(items, targetPath) {
|
||||
return bridge.ErrNotFound
|
||||
}
|
||||
delete(idx.Overrides, targetPath)
|
||||
writeIndex(ctx, client, idx.withItems(resolve(items, idx.Overrides)))
|
||||
return nil
|
||||
case ActionAdopt:
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
target, ok := findBySourcePath(items, targetPath)
|
||||
if !ok {
|
||||
return bridge.ErrNotFound
|
||||
}
|
||||
if target.Managed {
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
for _, item := range items {
|
||||
if item.Name == target.Name && item.Managed {
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
}
|
||||
dirPath, err := ManagedSkillDirForName(target.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.Mkdir(ctx, dirPath); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := client.WriteRaw(ctx, path.Join(dirPath, "SKILL.md"), strings.NewReader(target.Raw)); err != nil {
|
||||
return err
|
||||
}
|
||||
idx := readIndex(ctx, client)
|
||||
writeIndex(ctx, client, idx.withItems(resolve(scan(ctx, client, DiscoveryRoots()), idx.Overrides)))
|
||||
return nil
|
||||
default:
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
func ParseFile(raw string, fallbackName string) Parsed {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
result := Parsed{
|
||||
Name: strings.TrimSpace(fallbackName),
|
||||
Content: trimmed,
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "---") {
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
|
||||
rest := trimmed[3:]
|
||||
rest = strings.TrimLeft(rest, " \t")
|
||||
if len(rest) > 0 && rest[0] == '\n' {
|
||||
rest = rest[1:]
|
||||
} else if len(rest) > 1 && rest[0] == '\r' && rest[1] == '\n' {
|
||||
rest = rest[2:]
|
||||
}
|
||||
closingIdx := strings.Index(rest, "\n---")
|
||||
if closingIdx < 0 {
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
|
||||
frontmatterRaw := rest[:closingIdx]
|
||||
body := rest[closingIdx+4:]
|
||||
body = strings.TrimLeft(body, "\r\n")
|
||||
result.Content = body
|
||||
|
||||
var fm struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
}
|
||||
if err := yaml.Unmarshal([]byte(frontmatterRaw), &fm); err != nil {
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
if strings.TrimSpace(fm.Name) != "" {
|
||||
result.Name = strings.TrimSpace(fm.Name)
|
||||
}
|
||||
result.Description = strings.TrimSpace(fm.Description)
|
||||
result.Metadata = fm.Metadata
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
|
||||
func IsValidName(name string) bool {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if name == "." || name == ".." {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(name, ".") || strings.Contains(name, "..") {
|
||||
return false
|
||||
}
|
||||
for _, r := range name {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeParsed(skill Parsed) Parsed {
|
||||
if strings.TrimSpace(skill.Name) == "" {
|
||||
skill.Name = "default"
|
||||
}
|
||||
skill.Name = strings.TrimSpace(skill.Name)
|
||||
skill.Description = strings.TrimSpace(skill.Description)
|
||||
skill.Content = strings.TrimSpace(skill.Content)
|
||||
if skill.Description == "" {
|
||||
skill.Description = skill.Name
|
||||
}
|
||||
if skill.Content == "" {
|
||||
skill.Content = skill.Description
|
||||
}
|
||||
return skill
|
||||
}
|
||||
|
||||
func scan(ctx context.Context, client fileClient, roots []Root) []Entry {
|
||||
items := make([]Entry, 0, 16)
|
||||
for _, root := range roots {
|
||||
entries, err := client.ListDirAll(ctx, root.Path, false)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
slices.SortFunc(entries, func(a, b *pb.FileEntry) int {
|
||||
return strings.Compare(a.GetPath(), b.GetPath())
|
||||
})
|
||||
for _, entry := range entries {
|
||||
if !entry.GetIsDir() {
|
||||
if path.Base(entry.GetPath()) != "SKILL.md" {
|
||||
continue
|
||||
}
|
||||
filePath := path.Join(root.Path, "SKILL.md")
|
||||
raw, err := readRawFile(ctx, client, filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
parsed := ParseFile(raw, "default")
|
||||
items = append(items, entryFromParsed(parsed, raw, root, filePath))
|
||||
continue
|
||||
}
|
||||
|
||||
name := path.Base(entry.GetPath())
|
||||
if name == "" || name == "." {
|
||||
continue
|
||||
}
|
||||
filePath := path.Join(root.Path, name, "SKILL.md")
|
||||
raw, err := readRawFile(ctx, client, filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
parsed := ParseFile(raw, name)
|
||||
items = append(items, entryFromParsed(parsed, raw, root, filePath))
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func resolve(items []Entry, overrides map[string]indexOverride) []Entry {
|
||||
byName := make(map[string][]Entry, len(items))
|
||||
for _, item := range items {
|
||||
byName[item.Name] = append(byName[item.Name], item)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(byName))
|
||||
for name := range byName {
|
||||
names = append(names, name)
|
||||
}
|
||||
slices.Sort(names)
|
||||
|
||||
out := make([]Entry, 0, len(items))
|
||||
for _, name := range names {
|
||||
group := byName[name]
|
||||
var effectivePath string
|
||||
for i := range group {
|
||||
if overrides[group[i].SourcePath].Disabled {
|
||||
group[i].State = StateDisabled
|
||||
out = append(out, group[i])
|
||||
continue
|
||||
}
|
||||
if effectivePath == "" {
|
||||
group[i].State = StateEffective
|
||||
effectivePath = group[i].SourcePath
|
||||
out = append(out, group[i])
|
||||
continue
|
||||
}
|
||||
group[i].State = StateShadowed
|
||||
group[i].ShadowedBy = effectivePath
|
||||
out = append(out, group[i])
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(out, func(a, b Entry) int {
|
||||
if cmp := strings.Compare(a.Name, b.Name); cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
if cmp := stateRank(a.State) - stateRank(b.State); cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
if a.Managed != b.Managed {
|
||||
if a.Managed {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.SourcePath, b.SourcePath)
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func stateRank(state string) int {
|
||||
switch state {
|
||||
case StateEffective:
|
||||
return 0
|
||||
case StateShadowed:
|
||||
return 1
|
||||
case StateDisabled:
|
||||
return 2
|
||||
default:
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
func entryFromParsed(parsed Parsed, raw string, root Root, sourcePath string) Entry {
|
||||
return Entry{
|
||||
Name: parsed.Name,
|
||||
Description: parsed.Description,
|
||||
Content: parsed.Content,
|
||||
Metadata: parsed.Metadata,
|
||||
Raw: raw,
|
||||
SourcePath: sourcePath,
|
||||
SourceRoot: root.Path,
|
||||
SourceKind: root.Kind,
|
||||
Managed: root.Managed,
|
||||
}
|
||||
}
|
||||
|
||||
func readRawFile(ctx context.Context, client fileClient, filePath string) (string, error) {
|
||||
rc, err := client.ReadRaw(ctx, filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = rc.Close() }()
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func readIndex(ctx context.Context, client fileClient) indexState {
|
||||
rc, err := client.ReadRaw(ctx, IndexFilePath)
|
||||
if err != nil {
|
||||
return indexState{Version: 1, Overrides: make(map[string]indexOverride)}
|
||||
}
|
||||
defer func() { _ = rc.Close() }()
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
if err != nil || len(data) == 0 {
|
||||
return indexState{Version: 1, Overrides: make(map[string]indexOverride)}
|
||||
}
|
||||
|
||||
var idx indexState
|
||||
if err := json.Unmarshal(data, &idx); err != nil {
|
||||
return indexState{Version: 1, Overrides: make(map[string]indexOverride)}
|
||||
}
|
||||
if idx.Version == 0 {
|
||||
idx.Version = 1
|
||||
}
|
||||
if idx.Overrides == nil {
|
||||
idx.Overrides = make(map[string]indexOverride)
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func writeIndex(ctx context.Context, client fileClient, idx indexState) {
|
||||
if err := client.Mkdir(ctx, IndexDirPath); err != nil {
|
||||
return
|
||||
}
|
||||
data, err := json.MarshalIndent(idx, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = client.WriteRaw(ctx, IndexFilePath, strings.NewReader(string(data)))
|
||||
}
|
||||
|
||||
func (i indexState) withItems(items []Entry) indexState {
|
||||
if i.Version == 0 {
|
||||
i.Version = 1
|
||||
}
|
||||
if i.Overrides == nil {
|
||||
i.Overrides = make(map[string]indexOverride)
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
i.UpdatedAt = now
|
||||
i.Items = make([]indexedItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
sum := sha256.Sum256([]byte(item.Raw))
|
||||
i.Items = append(i.Items, indexedItem{
|
||||
Name: item.Name,
|
||||
SourcePath: item.SourcePath,
|
||||
SourceKind: item.SourceKind,
|
||||
Managed: item.Managed,
|
||||
State: item.State,
|
||||
ShadowedBy: item.ShadowedBy,
|
||||
ContentHash: hex.EncodeToString(sum[:]),
|
||||
LastSeenAt: now,
|
||||
})
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func containsSourcePath(items []Entry, target string) bool {
|
||||
_, ok := findBySourcePath(items, target)
|
||||
return ok
|
||||
}
|
||||
|
||||
func findBySourcePath(items []Entry, target string) (Entry, bool) {
|
||||
for _, item := range items {
|
||||
if item.SourcePath == target {
|
||||
return item, true
|
||||
}
|
||||
}
|
||||
return Entry{}, false
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||
)
|
||||
|
||||
func TestParseFileFallbacks(t *testing.T) {
|
||||
raw := "# Use this skill\n\nDo something useful."
|
||||
got := ParseFile(raw, "plain-skill")
|
||||
|
||||
if got.Name != "plain-skill" {
|
||||
t.Fatalf("expected name plain-skill, got %q", got.Name)
|
||||
}
|
||||
if got.Description != "plain-skill" {
|
||||
t.Fatalf("expected description plain-skill, got %q", got.Description)
|
||||
}
|
||||
if got.Content != raw {
|
||||
t.Fatalf("expected content to keep original markdown, got %q", got.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSupportsDisabledFallbackAndShadowing(t *testing.T) {
|
||||
items := []Entry{
|
||||
{Name: "alpha", SourcePath: "/data/skills/alpha/SKILL.md", Managed: true, SourceKind: SourceKindManaged},
|
||||
{Name: "alpha", SourcePath: "/data/.agents/skills/alpha/SKILL.md", SourceKind: SourceKindCompat},
|
||||
{Name: "beta", SourcePath: "/data/.agents/skills/beta/SKILL.md", SourceKind: SourceKindCompat},
|
||||
}
|
||||
|
||||
resolved := resolve(items, map[string]indexOverride{
|
||||
"/data/skills/alpha/SKILL.md": {Disabled: true},
|
||||
})
|
||||
|
||||
managedAlpha, ok := findBySourcePath(resolved, "/data/skills/alpha/SKILL.md")
|
||||
if !ok {
|
||||
t.Fatalf("managed alpha not found in resolved items")
|
||||
}
|
||||
if managedAlpha.State != StateDisabled {
|
||||
t.Fatalf("managed alpha state = %q, want disabled", managedAlpha.State)
|
||||
}
|
||||
compatAlpha, ok := findBySourcePath(resolved, "/data/.agents/skills/alpha/SKILL.md")
|
||||
if !ok {
|
||||
t.Fatalf("compat alpha not found in resolved items")
|
||||
}
|
||||
if compatAlpha.State != StateEffective {
|
||||
t.Fatalf("compat alpha state = %q, want effective", compatAlpha.State)
|
||||
}
|
||||
beta, ok := findBySourcePath(resolved, "/data/.agents/skills/beta/SKILL.md")
|
||||
if !ok {
|
||||
t.Fatalf("beta not found in resolved items")
|
||||
}
|
||||
if beta.State != StateEffective {
|
||||
t.Fatalf("beta state = %q, want effective", beta.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListReadsFullRawContentAndWritesIndex(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
client.listings[ManagedDirPath] = []*pb.FileEntry{{Path: "alpha", IsDir: true}}
|
||||
client.files[pathJoin(ManagedDirPath, "alpha", "SKILL.md")] = "---\nname: alpha\ndescription: Alpha\n---\n\n" + strings.Repeat("A", 7000)
|
||||
|
||||
items, err := List(context.Background(), client)
|
||||
if err != nil {
|
||||
t.Fatalf("List returned error: %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected 1 item, got %d", len(items))
|
||||
}
|
||||
if len(items[0].Raw) <= 7000 {
|
||||
t.Fatalf("expected full raw content, got len=%d", len(items[0].Raw))
|
||||
}
|
||||
if _, ok := client.files[IndexFilePath]; !ok {
|
||||
t.Fatalf("expected index file to be written")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyActionAdoptAndDisable(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
externalPath := pathJoin("/data/.agents/skills", "alpha", "SKILL.md")
|
||||
client.listings["/data/.agents/skills"] = []*pb.FileEntry{{Path: "alpha", IsDir: true}}
|
||||
client.files[externalPath] = "---\nname: alpha\ndescription: Alpha\n---\n\n# Alpha"
|
||||
|
||||
if err := ApplyAction(context.Background(), client, ActionRequest{
|
||||
Action: ActionAdopt,
|
||||
TargetPath: externalPath,
|
||||
}); err != nil {
|
||||
t.Fatalf("adopt returned error: %v", err)
|
||||
}
|
||||
if _, ok := client.files[pathJoin(ManagedDirPath, "alpha", "SKILL.md")]; !ok {
|
||||
t.Fatalf("expected managed copy after adopt")
|
||||
}
|
||||
|
||||
if err := ApplyAction(context.Background(), client, ActionRequest{
|
||||
Action: ActionDisable,
|
||||
TargetPath: externalPath,
|
||||
}); err != nil {
|
||||
t.Fatalf("disable returned error: %v", err)
|
||||
}
|
||||
idx := readIndex(context.Background(), client)
|
||||
if !idx.Overrides[externalPath].Disabled {
|
||||
t.Fatalf("expected disabled override for %s", externalPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyActionAdoptRejectsInvalidManagedName(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
externalPath := pathJoin("/data/.agents/skills", "escape", "SKILL.md")
|
||||
client.listings["/data/.agents/skills"] = []*pb.FileEntry{{Path: "escape", IsDir: true}}
|
||||
client.files[externalPath] = "---\nname: ..\ndescription: Escape\n---\n\n# Escape"
|
||||
|
||||
err := ApplyAction(context.Background(), client, ActionRequest{
|
||||
Action: ActionAdopt,
|
||||
TargetPath: externalPath,
|
||||
})
|
||||
if !errors.Is(err, bridge.ErrBadRequest) {
|
||||
t.Fatalf("adopt err = %v, want ErrBadRequest", err)
|
||||
}
|
||||
if _, ok := client.files[pathJoin(ManagedDirPath, "..", "SKILL.md")]; ok {
|
||||
t.Fatalf("unexpected managed write for invalid adopted name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidNameRejectsTraversalPatterns(t *testing.T) {
|
||||
for _, name := range []string{
|
||||
"",
|
||||
".",
|
||||
"..",
|
||||
".hidden",
|
||||
"alpha..beta",
|
||||
"../escape",
|
||||
"alpha/../beta",
|
||||
} {
|
||||
if IsValidName(name) {
|
||||
t.Fatalf("IsValidName(%q) = true, want false", name)
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range []string{"alpha", "alpha-beta", "alpha_beta", "alpha.beta"} {
|
||||
if !IsValidName(name) {
|
||||
t.Fatalf("IsValidName(%q) = false, want true", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagedSkillDirForNameRejectsEscapingNames(t *testing.T) {
|
||||
for _, name := range []string{".", "..", ".alpha", "alpha..beta"} {
|
||||
if _, err := ManagedSkillDirForName(name); !errors.Is(err, bridge.ErrBadRequest) {
|
||||
t.Fatalf("ManagedSkillDirForName(%q) err = %v, want ErrBadRequest", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
dirPath, err := ManagedSkillDirForName("alpha.beta")
|
||||
if err != nil {
|
||||
t.Fatalf("ManagedSkillDirForName(valid) returned error: %v", err)
|
||||
}
|
||||
if dirPath != pathJoin(ManagedDirPath, "alpha.beta") {
|
||||
t.Fatalf("ManagedSkillDirForName(valid) = %q, want %q", dirPath, pathJoin(ManagedDirPath, "alpha.beta"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoveryRootsMatchCurrentPolicy(t *testing.T) {
|
||||
roots := DiscoveryRoots()
|
||||
want := []Root{
|
||||
{Path: ManagedDirPath, Kind: SourceKindManaged, Managed: true},
|
||||
{Path: LegacyDirPath, Kind: SourceKindLegacy, Managed: false},
|
||||
{Path: "/data/.agents/skills", Kind: SourceKindCompat, Managed: false},
|
||||
{Path: "/root/.agents/skills", Kind: SourceKindCompat, Managed: false},
|
||||
}
|
||||
if !slices.Equal(roots, want) {
|
||||
t.Fatalf("DiscoveryRoots() = %+v, want %+v", roots, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListScansConfiguredDiscoveryRootsInOrder(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
for _, root := range DiscoveryRoots() {
|
||||
client.listings[root.Path] = nil
|
||||
}
|
||||
client.listings[ManagedDirPath] = []*pb.FileEntry{{Path: "alpha", IsDir: true}}
|
||||
client.files[pathJoin(ManagedDirPath, "alpha", "SKILL.md")] = "---\nname: alpha\ndescription: Alpha\n---\n\n# Alpha"
|
||||
|
||||
items, err := List(context.Background(), client)
|
||||
if err != nil {
|
||||
t.Fatalf("List returned error: %v", err)
|
||||
}
|
||||
if len(items) != 1 || items[0].SourceRoot != ManagedDirPath {
|
||||
t.Fatalf("List() items = %+v, want managed alpha only", items)
|
||||
}
|
||||
|
||||
wantCalls := make([]string, 0, len(DiscoveryRoots()))
|
||||
for _, root := range DiscoveryRoots() {
|
||||
wantCalls = append(wantCalls, root.Path)
|
||||
}
|
||||
if !slices.Equal(client.listCalls, wantCalls) {
|
||||
t.Fatalf("ListDirAll calls = %+v, want %+v", client.listCalls, wantCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainerEnvUsesDataHomeAndXDGDirs(t *testing.T) {
|
||||
env := ContainerEnv()
|
||||
for _, want := range []string{
|
||||
"HOME=/data",
|
||||
"XDG_CONFIG_HOME=/data/.config",
|
||||
"XDG_DATA_HOME=/data/.local/share",
|
||||
"XDG_CACHE_HOME=/data/.cache",
|
||||
} {
|
||||
if !slices.Contains(env, want) {
|
||||
t.Fatalf("env %+v does not contain %q", env, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
listings map[string][]*pb.FileEntry
|
||||
files map[string]string
|
||||
listCalls []string
|
||||
}
|
||||
|
||||
func newFakeClient() *fakeClient {
|
||||
return &fakeClient{
|
||||
listings: make(map[string][]*pb.FileEntry),
|
||||
files: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeClient) ListDirAll(_ context.Context, p string, _ bool) ([]*pb.FileEntry, error) {
|
||||
f.listCalls = append(f.listCalls, p)
|
||||
items, ok := f.listings[p]
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (f *fakeClient) ReadRaw(_ context.Context, p string) (io.ReadCloser, error) {
|
||||
content, ok := f.files[p]
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return io.NopCloser(strings.NewReader(content)), nil
|
||||
}
|
||||
|
||||
func (f *fakeClient) WriteRaw(_ context.Context, p string, r io.Reader) (int64, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
f.files[p] = string(data)
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (*fakeClient) Mkdir(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func pathJoin(parts ...string) string {
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
skillset "github.com/memohai/memoh/internal/skills"
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
)
|
||||
|
||||
@@ -243,9 +244,10 @@ func (m *Manager) buildWorkspaceContainerSpec(botID string, gpu WorkspaceGPUConf
|
||||
tzMounts, tzEnv := ctr.TimezoneSpec()
|
||||
mounts = append(mounts, tzMounts...)
|
||||
|
||||
env := make([]string, 0, len(tzEnv)+1)
|
||||
env := make([]string, 0, len(tzEnv)+1+len(skillset.ContainerEnv()))
|
||||
env = append(env, tzEnv...)
|
||||
env = append(env, "BRIDGE_SOCKET_PATH=/run/memoh/bridge.sock")
|
||||
env = append(env, skillset.ContainerEnv()...)
|
||||
|
||||
return ctr.ContainerSpec{
|
||||
Cmd: []string{"/opt/memoh/bridge"},
|
||||
|
||||
Reference in New Issue
Block a user