mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(skills): add effective skill resolution and actions (#377)
* feat(skills): add effective skill resolution and actions * refactor(workspace): normalize skill-related env and prompt * chore(api): regenerate skills OpenAPI and SDK artifacts * feat(web): surface effective skill state in console * test(skills): cover API and runtime effective state * fix(web): show adopt action for discovered skills * fix(web): align skill header and show stateful visibility icon * refactor(web): compact skill metadata on narrow layouts * fix(web): constrain long skill text in cards * refactor(skills): narrow default discovery roots * fix(skills): harden managed skill path validation * feat: add path in the results of `use_skill` --------- Co-authored-by: Acbox <acbox0328@gmail.com>
This commit is contained in:
@@ -0,0 +1,509 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||
)
|
||||
|
||||
const (
|
||||
ManagedDirPath = config.DefaultDataMount + "/skills"
|
||||
LegacyDirPath = config.DefaultDataMount + "/.skills"
|
||||
IndexDirPath = config.DefaultDataMount + "/.memoh/skills"
|
||||
IndexFilePath = IndexDirPath + "/index.json"
|
||||
|
||||
SourceKindManaged = "managed"
|
||||
SourceKindLegacy = "legacy"
|
||||
SourceKindCompat = "compat"
|
||||
|
||||
StateEffective = "effective"
|
||||
StateShadowed = "shadowed"
|
||||
StateDisabled = "disabled"
|
||||
|
||||
ActionAdopt = "adopt"
|
||||
ActionDisable = "disable"
|
||||
ActionEnable = "enable"
|
||||
)
|
||||
|
||||
type Entry struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Raw string `json:"raw"`
|
||||
SourcePath string `json:"source_path,omitempty"`
|
||||
SourceRoot string `json:"source_root,omitempty"`
|
||||
SourceKind string `json:"source_kind,omitempty"`
|
||||
Managed bool `json:"managed,omitempty"`
|
||||
State string `json:"state,omitempty"`
|
||||
ShadowedBy string `json:"shadowed_by,omitempty"`
|
||||
}
|
||||
|
||||
type ActionRequest struct {
|
||||
Action string
|
||||
TargetPath string
|
||||
}
|
||||
|
||||
type Parsed struct {
|
||||
Name string
|
||||
Description string
|
||||
Content string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
type Root struct {
|
||||
Path string
|
||||
Kind string
|
||||
Managed bool
|
||||
}
|
||||
|
||||
type fileClient interface {
|
||||
ListDirAll(ctx context.Context, path string, recursive bool) ([]*pb.FileEntry, error)
|
||||
ReadRaw(ctx context.Context, path string) (io.ReadCloser, error)
|
||||
WriteRaw(ctx context.Context, path string, r io.Reader) (int64, error)
|
||||
Mkdir(ctx context.Context, path string) error
|
||||
}
|
||||
|
||||
type indexState struct {
|
||||
Version int `json:"version"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
Overrides map[string]indexOverride `json:"overrides,omitempty"`
|
||||
Items []indexedItem `json:"items,omitempty"`
|
||||
}
|
||||
|
||||
type indexOverride struct {
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
type indexedItem struct {
|
||||
Name string `json:"name"`
|
||||
SourcePath string `json:"source_path"`
|
||||
SourceKind string `json:"source_kind"`
|
||||
Managed bool `json:"managed"`
|
||||
State string `json:"state"`
|
||||
ShadowedBy string `json:"shadowed_by,omitempty"`
|
||||
ContentHash string `json:"content_hash,omitempty"`
|
||||
LastSeenAt string `json:"last_seen_at,omitempty"`
|
||||
}
|
||||
|
||||
func ManagedDir() string {
|
||||
return ManagedDirPath
|
||||
}
|
||||
|
||||
func ManagedSkillDirForName(name string) (string, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if !IsValidName(name) {
|
||||
return "", bridge.ErrBadRequest
|
||||
}
|
||||
|
||||
dirPath := path.Clean(path.Join(ManagedDirPath, name))
|
||||
if dirPath == ManagedDirPath || !strings.HasPrefix(dirPath, ManagedDirPath+"/") {
|
||||
return "", bridge.ErrBadRequest
|
||||
}
|
||||
return dirPath, nil
|
||||
}
|
||||
|
||||
func ContainerEnv() []string {
|
||||
return []string{
|
||||
"HOME=" + config.DefaultDataMount,
|
||||
"XDG_CONFIG_HOME=" + path.Join(config.DefaultDataMount, ".config"),
|
||||
"XDG_DATA_HOME=" + path.Join(config.DefaultDataMount, ".local", "share"),
|
||||
"XDG_CACHE_HOME=" + path.Join(config.DefaultDataMount, ".cache"),
|
||||
}
|
||||
}
|
||||
|
||||
func DiscoveryRoots() []Root {
|
||||
return []Root{
|
||||
{Path: ManagedDirPath, Kind: SourceKindManaged, Managed: true},
|
||||
{Path: LegacyDirPath, Kind: SourceKindLegacy, Managed: false},
|
||||
{Path: path.Join(config.DefaultDataMount, ".agents", "skills"), Kind: SourceKindCompat, Managed: false},
|
||||
{Path: path.Join("/root", ".agents", "skills"), Kind: SourceKindCompat, Managed: false},
|
||||
}
|
||||
}
|
||||
|
||||
func List(ctx context.Context, client fileClient) ([]Entry, error) {
|
||||
idx := readIndex(ctx, client)
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
resolved := resolve(items, idx.Overrides)
|
||||
writeIndex(ctx, client, idx.withItems(resolved))
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func LoadEffective(ctx context.Context, client fileClient) ([]Entry, error) {
|
||||
items, err := List(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]Entry, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.State == StateEffective {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func ApplyAction(ctx context.Context, client fileClient, req ActionRequest) error {
|
||||
targetPath := strings.TrimSpace(req.TargetPath)
|
||||
if targetPath == "" {
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
|
||||
switch strings.TrimSpace(req.Action) {
|
||||
case ActionDisable:
|
||||
idx := readIndex(ctx, client)
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
if !containsSourcePath(items, targetPath) {
|
||||
return bridge.ErrNotFound
|
||||
}
|
||||
if idx.Overrides == nil {
|
||||
idx.Overrides = make(map[string]indexOverride)
|
||||
}
|
||||
idx.Overrides[targetPath] = indexOverride{Disabled: true}
|
||||
writeIndex(ctx, client, idx.withItems(resolve(items, idx.Overrides)))
|
||||
return nil
|
||||
case ActionEnable:
|
||||
idx := readIndex(ctx, client)
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
if !containsSourcePath(items, targetPath) {
|
||||
return bridge.ErrNotFound
|
||||
}
|
||||
delete(idx.Overrides, targetPath)
|
||||
writeIndex(ctx, client, idx.withItems(resolve(items, idx.Overrides)))
|
||||
return nil
|
||||
case ActionAdopt:
|
||||
items := scan(ctx, client, DiscoveryRoots())
|
||||
target, ok := findBySourcePath(items, targetPath)
|
||||
if !ok {
|
||||
return bridge.ErrNotFound
|
||||
}
|
||||
if target.Managed {
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
for _, item := range items {
|
||||
if item.Name == target.Name && item.Managed {
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
}
|
||||
dirPath, err := ManagedSkillDirForName(target.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.Mkdir(ctx, dirPath); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := client.WriteRaw(ctx, path.Join(dirPath, "SKILL.md"), strings.NewReader(target.Raw)); err != nil {
|
||||
return err
|
||||
}
|
||||
idx := readIndex(ctx, client)
|
||||
writeIndex(ctx, client, idx.withItems(resolve(scan(ctx, client, DiscoveryRoots()), idx.Overrides)))
|
||||
return nil
|
||||
default:
|
||||
return bridge.ErrBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
func ParseFile(raw string, fallbackName string) Parsed {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
result := Parsed{
|
||||
Name: strings.TrimSpace(fallbackName),
|
||||
Content: trimmed,
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "---") {
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
|
||||
rest := trimmed[3:]
|
||||
rest = strings.TrimLeft(rest, " \t")
|
||||
if len(rest) > 0 && rest[0] == '\n' {
|
||||
rest = rest[1:]
|
||||
} else if len(rest) > 1 && rest[0] == '\r' && rest[1] == '\n' {
|
||||
rest = rest[2:]
|
||||
}
|
||||
closingIdx := strings.Index(rest, "\n---")
|
||||
if closingIdx < 0 {
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
|
||||
frontmatterRaw := rest[:closingIdx]
|
||||
body := rest[closingIdx+4:]
|
||||
body = strings.TrimLeft(body, "\r\n")
|
||||
result.Content = body
|
||||
|
||||
var fm struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
Metadata map[string]any `yaml:"metadata"`
|
||||
}
|
||||
if err := yaml.Unmarshal([]byte(frontmatterRaw), &fm); err != nil {
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
if strings.TrimSpace(fm.Name) != "" {
|
||||
result.Name = strings.TrimSpace(fm.Name)
|
||||
}
|
||||
result.Description = strings.TrimSpace(fm.Description)
|
||||
result.Metadata = fm.Metadata
|
||||
return normalizeParsed(result)
|
||||
}
|
||||
|
||||
func IsValidName(name string) bool {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if name == "." || name == ".." {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(name, ".") || strings.Contains(name, "..") {
|
||||
return false
|
||||
}
|
||||
for _, r := range name {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeParsed(skill Parsed) Parsed {
|
||||
if strings.TrimSpace(skill.Name) == "" {
|
||||
skill.Name = "default"
|
||||
}
|
||||
skill.Name = strings.TrimSpace(skill.Name)
|
||||
skill.Description = strings.TrimSpace(skill.Description)
|
||||
skill.Content = strings.TrimSpace(skill.Content)
|
||||
if skill.Description == "" {
|
||||
skill.Description = skill.Name
|
||||
}
|
||||
if skill.Content == "" {
|
||||
skill.Content = skill.Description
|
||||
}
|
||||
return skill
|
||||
}
|
||||
|
||||
func scan(ctx context.Context, client fileClient, roots []Root) []Entry {
|
||||
items := make([]Entry, 0, 16)
|
||||
for _, root := range roots {
|
||||
entries, err := client.ListDirAll(ctx, root.Path, false)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
slices.SortFunc(entries, func(a, b *pb.FileEntry) int {
|
||||
return strings.Compare(a.GetPath(), b.GetPath())
|
||||
})
|
||||
for _, entry := range entries {
|
||||
if !entry.GetIsDir() {
|
||||
if path.Base(entry.GetPath()) != "SKILL.md" {
|
||||
continue
|
||||
}
|
||||
filePath := path.Join(root.Path, "SKILL.md")
|
||||
raw, err := readRawFile(ctx, client, filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
parsed := ParseFile(raw, "default")
|
||||
items = append(items, entryFromParsed(parsed, raw, root, filePath))
|
||||
continue
|
||||
}
|
||||
|
||||
name := path.Base(entry.GetPath())
|
||||
if name == "" || name == "." {
|
||||
continue
|
||||
}
|
||||
filePath := path.Join(root.Path, name, "SKILL.md")
|
||||
raw, err := readRawFile(ctx, client, filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
parsed := ParseFile(raw, name)
|
||||
items = append(items, entryFromParsed(parsed, raw, root, filePath))
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func resolve(items []Entry, overrides map[string]indexOverride) []Entry {
|
||||
byName := make(map[string][]Entry, len(items))
|
||||
for _, item := range items {
|
||||
byName[item.Name] = append(byName[item.Name], item)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(byName))
|
||||
for name := range byName {
|
||||
names = append(names, name)
|
||||
}
|
||||
slices.Sort(names)
|
||||
|
||||
out := make([]Entry, 0, len(items))
|
||||
for _, name := range names {
|
||||
group := byName[name]
|
||||
var effectivePath string
|
||||
for i := range group {
|
||||
if overrides[group[i].SourcePath].Disabled {
|
||||
group[i].State = StateDisabled
|
||||
out = append(out, group[i])
|
||||
continue
|
||||
}
|
||||
if effectivePath == "" {
|
||||
group[i].State = StateEffective
|
||||
effectivePath = group[i].SourcePath
|
||||
out = append(out, group[i])
|
||||
continue
|
||||
}
|
||||
group[i].State = StateShadowed
|
||||
group[i].ShadowedBy = effectivePath
|
||||
out = append(out, group[i])
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(out, func(a, b Entry) int {
|
||||
if cmp := strings.Compare(a.Name, b.Name); cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
if cmp := stateRank(a.State) - stateRank(b.State); cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
if a.Managed != b.Managed {
|
||||
if a.Managed {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.SourcePath, b.SourcePath)
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func stateRank(state string) int {
|
||||
switch state {
|
||||
case StateEffective:
|
||||
return 0
|
||||
case StateShadowed:
|
||||
return 1
|
||||
case StateDisabled:
|
||||
return 2
|
||||
default:
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
func entryFromParsed(parsed Parsed, raw string, root Root, sourcePath string) Entry {
|
||||
return Entry{
|
||||
Name: parsed.Name,
|
||||
Description: parsed.Description,
|
||||
Content: parsed.Content,
|
||||
Metadata: parsed.Metadata,
|
||||
Raw: raw,
|
||||
SourcePath: sourcePath,
|
||||
SourceRoot: root.Path,
|
||||
SourceKind: root.Kind,
|
||||
Managed: root.Managed,
|
||||
}
|
||||
}
|
||||
|
||||
func readRawFile(ctx context.Context, client fileClient, filePath string) (string, error) {
|
||||
rc, err := client.ReadRaw(ctx, filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = rc.Close() }()
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func readIndex(ctx context.Context, client fileClient) indexState {
|
||||
rc, err := client.ReadRaw(ctx, IndexFilePath)
|
||||
if err != nil {
|
||||
return indexState{Version: 1, Overrides: make(map[string]indexOverride)}
|
||||
}
|
||||
defer func() { _ = rc.Close() }()
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
if err != nil || len(data) == 0 {
|
||||
return indexState{Version: 1, Overrides: make(map[string]indexOverride)}
|
||||
}
|
||||
|
||||
var idx indexState
|
||||
if err := json.Unmarshal(data, &idx); err != nil {
|
||||
return indexState{Version: 1, Overrides: make(map[string]indexOverride)}
|
||||
}
|
||||
if idx.Version == 0 {
|
||||
idx.Version = 1
|
||||
}
|
||||
if idx.Overrides == nil {
|
||||
idx.Overrides = make(map[string]indexOverride)
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func writeIndex(ctx context.Context, client fileClient, idx indexState) {
|
||||
if err := client.Mkdir(ctx, IndexDirPath); err != nil {
|
||||
return
|
||||
}
|
||||
data, err := json.MarshalIndent(idx, "", " ")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = client.WriteRaw(ctx, IndexFilePath, strings.NewReader(string(data)))
|
||||
}
|
||||
|
||||
func (i indexState) withItems(items []Entry) indexState {
|
||||
if i.Version == 0 {
|
||||
i.Version = 1
|
||||
}
|
||||
if i.Overrides == nil {
|
||||
i.Overrides = make(map[string]indexOverride)
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
i.UpdatedAt = now
|
||||
i.Items = make([]indexedItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
sum := sha256.Sum256([]byte(item.Raw))
|
||||
i.Items = append(i.Items, indexedItem{
|
||||
Name: item.Name,
|
||||
SourcePath: item.SourcePath,
|
||||
SourceKind: item.SourceKind,
|
||||
Managed: item.Managed,
|
||||
State: item.State,
|
||||
ShadowedBy: item.ShadowedBy,
|
||||
ContentHash: hex.EncodeToString(sum[:]),
|
||||
LastSeenAt: now,
|
||||
})
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func containsSourcePath(items []Entry, target string) bool {
|
||||
_, ok := findBySourcePath(items, target)
|
||||
return ok
|
||||
}
|
||||
|
||||
func findBySourcePath(items []Entry, target string) (Entry, bool) {
|
||||
for _, item := range items {
|
||||
if item.SourcePath == target {
|
||||
return item, true
|
||||
}
|
||||
}
|
||||
return Entry{}, false
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/workspace/bridge"
|
||||
pb "github.com/memohai/memoh/internal/workspace/bridgepb"
|
||||
)
|
||||
|
||||
func TestParseFileFallbacks(t *testing.T) {
|
||||
raw := "# Use this skill\n\nDo something useful."
|
||||
got := ParseFile(raw, "plain-skill")
|
||||
|
||||
if got.Name != "plain-skill" {
|
||||
t.Fatalf("expected name plain-skill, got %q", got.Name)
|
||||
}
|
||||
if got.Description != "plain-skill" {
|
||||
t.Fatalf("expected description plain-skill, got %q", got.Description)
|
||||
}
|
||||
if got.Content != raw {
|
||||
t.Fatalf("expected content to keep original markdown, got %q", got.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSupportsDisabledFallbackAndShadowing(t *testing.T) {
|
||||
items := []Entry{
|
||||
{Name: "alpha", SourcePath: "/data/skills/alpha/SKILL.md", Managed: true, SourceKind: SourceKindManaged},
|
||||
{Name: "alpha", SourcePath: "/data/.agents/skills/alpha/SKILL.md", SourceKind: SourceKindCompat},
|
||||
{Name: "beta", SourcePath: "/data/.agents/skills/beta/SKILL.md", SourceKind: SourceKindCompat},
|
||||
}
|
||||
|
||||
resolved := resolve(items, map[string]indexOverride{
|
||||
"/data/skills/alpha/SKILL.md": {Disabled: true},
|
||||
})
|
||||
|
||||
managedAlpha, ok := findBySourcePath(resolved, "/data/skills/alpha/SKILL.md")
|
||||
if !ok {
|
||||
t.Fatalf("managed alpha not found in resolved items")
|
||||
}
|
||||
if managedAlpha.State != StateDisabled {
|
||||
t.Fatalf("managed alpha state = %q, want disabled", managedAlpha.State)
|
||||
}
|
||||
compatAlpha, ok := findBySourcePath(resolved, "/data/.agents/skills/alpha/SKILL.md")
|
||||
if !ok {
|
||||
t.Fatalf("compat alpha not found in resolved items")
|
||||
}
|
||||
if compatAlpha.State != StateEffective {
|
||||
t.Fatalf("compat alpha state = %q, want effective", compatAlpha.State)
|
||||
}
|
||||
beta, ok := findBySourcePath(resolved, "/data/.agents/skills/beta/SKILL.md")
|
||||
if !ok {
|
||||
t.Fatalf("beta not found in resolved items")
|
||||
}
|
||||
if beta.State != StateEffective {
|
||||
t.Fatalf("beta state = %q, want effective", beta.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListReadsFullRawContentAndWritesIndex(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
client.listings[ManagedDirPath] = []*pb.FileEntry{{Path: "alpha", IsDir: true}}
|
||||
client.files[pathJoin(ManagedDirPath, "alpha", "SKILL.md")] = "---\nname: alpha\ndescription: Alpha\n---\n\n" + strings.Repeat("A", 7000)
|
||||
|
||||
items, err := List(context.Background(), client)
|
||||
if err != nil {
|
||||
t.Fatalf("List returned error: %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected 1 item, got %d", len(items))
|
||||
}
|
||||
if len(items[0].Raw) <= 7000 {
|
||||
t.Fatalf("expected full raw content, got len=%d", len(items[0].Raw))
|
||||
}
|
||||
if _, ok := client.files[IndexFilePath]; !ok {
|
||||
t.Fatalf("expected index file to be written")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyActionAdoptAndDisable(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
externalPath := pathJoin("/data/.agents/skills", "alpha", "SKILL.md")
|
||||
client.listings["/data/.agents/skills"] = []*pb.FileEntry{{Path: "alpha", IsDir: true}}
|
||||
client.files[externalPath] = "---\nname: alpha\ndescription: Alpha\n---\n\n# Alpha"
|
||||
|
||||
if err := ApplyAction(context.Background(), client, ActionRequest{
|
||||
Action: ActionAdopt,
|
||||
TargetPath: externalPath,
|
||||
}); err != nil {
|
||||
t.Fatalf("adopt returned error: %v", err)
|
||||
}
|
||||
if _, ok := client.files[pathJoin(ManagedDirPath, "alpha", "SKILL.md")]; !ok {
|
||||
t.Fatalf("expected managed copy after adopt")
|
||||
}
|
||||
|
||||
if err := ApplyAction(context.Background(), client, ActionRequest{
|
||||
Action: ActionDisable,
|
||||
TargetPath: externalPath,
|
||||
}); err != nil {
|
||||
t.Fatalf("disable returned error: %v", err)
|
||||
}
|
||||
idx := readIndex(context.Background(), client)
|
||||
if !idx.Overrides[externalPath].Disabled {
|
||||
t.Fatalf("expected disabled override for %s", externalPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyActionAdoptRejectsInvalidManagedName(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
externalPath := pathJoin("/data/.agents/skills", "escape", "SKILL.md")
|
||||
client.listings["/data/.agents/skills"] = []*pb.FileEntry{{Path: "escape", IsDir: true}}
|
||||
client.files[externalPath] = "---\nname: ..\ndescription: Escape\n---\n\n# Escape"
|
||||
|
||||
err := ApplyAction(context.Background(), client, ActionRequest{
|
||||
Action: ActionAdopt,
|
||||
TargetPath: externalPath,
|
||||
})
|
||||
if !errors.Is(err, bridge.ErrBadRequest) {
|
||||
t.Fatalf("adopt err = %v, want ErrBadRequest", err)
|
||||
}
|
||||
if _, ok := client.files[pathJoin(ManagedDirPath, "..", "SKILL.md")]; ok {
|
||||
t.Fatalf("unexpected managed write for invalid adopted name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidNameRejectsTraversalPatterns(t *testing.T) {
|
||||
for _, name := range []string{
|
||||
"",
|
||||
".",
|
||||
"..",
|
||||
".hidden",
|
||||
"alpha..beta",
|
||||
"../escape",
|
||||
"alpha/../beta",
|
||||
} {
|
||||
if IsValidName(name) {
|
||||
t.Fatalf("IsValidName(%q) = true, want false", name)
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range []string{"alpha", "alpha-beta", "alpha_beta", "alpha.beta"} {
|
||||
if !IsValidName(name) {
|
||||
t.Fatalf("IsValidName(%q) = false, want true", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagedSkillDirForNameRejectsEscapingNames(t *testing.T) {
|
||||
for _, name := range []string{".", "..", ".alpha", "alpha..beta"} {
|
||||
if _, err := ManagedSkillDirForName(name); !errors.Is(err, bridge.ErrBadRequest) {
|
||||
t.Fatalf("ManagedSkillDirForName(%q) err = %v, want ErrBadRequest", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
dirPath, err := ManagedSkillDirForName("alpha.beta")
|
||||
if err != nil {
|
||||
t.Fatalf("ManagedSkillDirForName(valid) returned error: %v", err)
|
||||
}
|
||||
if dirPath != pathJoin(ManagedDirPath, "alpha.beta") {
|
||||
t.Fatalf("ManagedSkillDirForName(valid) = %q, want %q", dirPath, pathJoin(ManagedDirPath, "alpha.beta"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoveryRootsMatchCurrentPolicy(t *testing.T) {
|
||||
roots := DiscoveryRoots()
|
||||
want := []Root{
|
||||
{Path: ManagedDirPath, Kind: SourceKindManaged, Managed: true},
|
||||
{Path: LegacyDirPath, Kind: SourceKindLegacy, Managed: false},
|
||||
{Path: "/data/.agents/skills", Kind: SourceKindCompat, Managed: false},
|
||||
{Path: "/root/.agents/skills", Kind: SourceKindCompat, Managed: false},
|
||||
}
|
||||
if !slices.Equal(roots, want) {
|
||||
t.Fatalf("DiscoveryRoots() = %+v, want %+v", roots, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListScansConfiguredDiscoveryRootsInOrder(t *testing.T) {
|
||||
client := newFakeClient()
|
||||
for _, root := range DiscoveryRoots() {
|
||||
client.listings[root.Path] = nil
|
||||
}
|
||||
client.listings[ManagedDirPath] = []*pb.FileEntry{{Path: "alpha", IsDir: true}}
|
||||
client.files[pathJoin(ManagedDirPath, "alpha", "SKILL.md")] = "---\nname: alpha\ndescription: Alpha\n---\n\n# Alpha"
|
||||
|
||||
items, err := List(context.Background(), client)
|
||||
if err != nil {
|
||||
t.Fatalf("List returned error: %v", err)
|
||||
}
|
||||
if len(items) != 1 || items[0].SourceRoot != ManagedDirPath {
|
||||
t.Fatalf("List() items = %+v, want managed alpha only", items)
|
||||
}
|
||||
|
||||
wantCalls := make([]string, 0, len(DiscoveryRoots()))
|
||||
for _, root := range DiscoveryRoots() {
|
||||
wantCalls = append(wantCalls, root.Path)
|
||||
}
|
||||
if !slices.Equal(client.listCalls, wantCalls) {
|
||||
t.Fatalf("ListDirAll calls = %+v, want %+v", client.listCalls, wantCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainerEnvUsesDataHomeAndXDGDirs(t *testing.T) {
|
||||
env := ContainerEnv()
|
||||
for _, want := range []string{
|
||||
"HOME=/data",
|
||||
"XDG_CONFIG_HOME=/data/.config",
|
||||
"XDG_DATA_HOME=/data/.local/share",
|
||||
"XDG_CACHE_HOME=/data/.cache",
|
||||
} {
|
||||
if !slices.Contains(env, want) {
|
||||
t.Fatalf("env %+v does not contain %q", env, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
listings map[string][]*pb.FileEntry
|
||||
files map[string]string
|
||||
listCalls []string
|
||||
}
|
||||
|
||||
func newFakeClient() *fakeClient {
|
||||
return &fakeClient{
|
||||
listings: make(map[string][]*pb.FileEntry),
|
||||
files: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeClient) ListDirAll(_ context.Context, p string, _ bool) ([]*pb.FileEntry, error) {
|
||||
f.listCalls = append(f.listCalls, p)
|
||||
items, ok := f.listings[p]
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (f *fakeClient) ReadRaw(_ context.Context, p string) (io.ReadCloser, error) {
|
||||
content, ok := f.files[p]
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return io.NopCloser(strings.NewReader(content)), nil
|
||||
}
|
||||
|
||||
func (f *fakeClient) WriteRaw(_ context.Context, p string, r io.Reader) (int64, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
f.files[p] = string(data)
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (*fakeClient) Mkdir(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func pathJoin(parts ...string) string {
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
Reference in New Issue
Block a user