mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
158 lines
4.0 KiB
Go
158 lines
4.0 KiB
Go
package builtin
|
|
|
|
import (
	"strconv"
	"strings"
	"testing"

	adapters "github.com/memohai/memoh/internal/memory/adapters"
)
|
|
|
|
func makeItems(texts ...string) []adapters.MemoryItem {
|
|
items := make([]adapters.MemoryItem, len(texts))
|
|
for i, text := range texts {
|
|
items[i] = adapters.MemoryItem{
|
|
ID: "id-" + text[:min(len(text), 8)],
|
|
Memory: text,
|
|
Score: float64(len(texts) - i),
|
|
}
|
|
}
|
|
return items
|
|
}
|
|
|
|
func TestPackContext_BasicPacking(t *testing.T) {
|
|
t.Parallel()
|
|
items := makeItems("alpha", "bravo", "charlie", "delta", "echo", "foxtrot")
|
|
cfg := contextPackerConfig{
|
|
TargetItems: 4,
|
|
MaxTotalChars: 2000,
|
|
MinItemChars: 3,
|
|
MaxItemChars: 100,
|
|
OverfetchRatio: 2,
|
|
}
|
|
result := packContext(items, cfg)
|
|
if len(result.Items) != 4 {
|
|
t.Fatalf("expected 4 packed items, got %d", len(result.Items))
|
|
}
|
|
for _, pi := range result.Items {
|
|
if pi.Snippet == "" {
|
|
t.Fatal("expected non-empty snippet")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPackContext_BudgetLimitsItems(t *testing.T) {
|
|
t.Parallel()
|
|
long := strings.Repeat("x", 500)
|
|
items := makeItems(long, long, long, long, long)
|
|
cfg := contextPackerConfig{
|
|
TargetItems: 5,
|
|
MaxTotalChars: 800,
|
|
MinItemChars: 100,
|
|
MaxItemChars: 500,
|
|
OverfetchRatio: 2,
|
|
}
|
|
result := packContext(items, cfg)
|
|
totalChars := 0
|
|
for _, pi := range result.Items {
|
|
totalChars += len([]rune(pi.Snippet))
|
|
}
|
|
if totalChars > cfg.MaxTotalChars+50 {
|
|
t.Fatalf("total chars %d exceeds budget %d by too much", totalChars, cfg.MaxTotalChars)
|
|
}
|
|
}
|
|
|
|
func TestPackContext_CompressesToFitMore(t *testing.T) {
|
|
t.Parallel()
|
|
medium := strings.Repeat("m", 200)
|
|
items := makeItems(medium, medium, medium, medium, medium, medium)
|
|
cfg := contextPackerConfig{
|
|
TargetItems: 6,
|
|
MaxTotalChars: 600,
|
|
MinItemChars: 50,
|
|
MaxItemChars: 200,
|
|
OverfetchRatio: 2,
|
|
}
|
|
result := packContext(items, cfg)
|
|
if len(result.Items) < 3 {
|
|
t.Fatalf("expected at least 3 items after compression, got %d", len(result.Items))
|
|
}
|
|
}
|
|
|
|
func TestPackContext_ShortItemsNotTruncated(t *testing.T) {
|
|
t.Parallel()
|
|
items := makeItems("hi", "yo", "ok")
|
|
cfg := contextPackerConfig{
|
|
TargetItems: 3,
|
|
MaxTotalChars: 1000,
|
|
MinItemChars: 10,
|
|
MaxItemChars: 200,
|
|
OverfetchRatio: 2,
|
|
}
|
|
result := packContext(items, cfg)
|
|
if len(result.Items) != 3 {
|
|
t.Fatalf("expected 3 items, got %d", len(result.Items))
|
|
}
|
|
for _, pi := range result.Items {
|
|
if strings.HasSuffix(pi.Snippet, "...") {
|
|
t.Fatalf("short item should not be truncated: %q", pi.Snippet)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPackContext_EmptyInput(t *testing.T) {
|
|
t.Parallel()
|
|
result := packContext(nil, defaultPackerConfig)
|
|
if len(result.Items) != 0 {
|
|
t.Fatalf("expected 0 items for nil input, got %d", len(result.Items))
|
|
}
|
|
}
|
|
|
|
func TestAntiLostInMiddle_Reordering(t *testing.T) {
|
|
t.Parallel()
|
|
items := []int{1, 2, 3, 4, 5}
|
|
reordered := antiLostInMiddle(items)
|
|
if reordered[0] != 1 {
|
|
t.Fatalf("expected first item to be 1, got %d", reordered[0])
|
|
}
|
|
if reordered[len(reordered)-1] != 2 {
|
|
t.Fatalf("expected last item to be 2, got %d", reordered[len(reordered)-1])
|
|
}
|
|
}
|
|
|
|
func TestAntiLostInMiddle_SmallSlice(t *testing.T) {
|
|
t.Parallel()
|
|
single := antiLostInMiddle([]string{"a"})
|
|
if len(single) != 1 || single[0] != "a" {
|
|
t.Fatalf("unexpected result for single item: %v", single)
|
|
}
|
|
pair := antiLostInMiddle([]string{"a", "b"})
|
|
if len(pair) != 2 {
|
|
t.Fatalf("unexpected result for pair: %v", pair)
|
|
}
|
|
}
|
|
|
|
func TestOverfetchLimit(t *testing.T) {
|
|
t.Parallel()
|
|
cfg := contextPackerConfig{TargetItems: 5, OverfetchRatio: 3}
|
|
if got := overfetchLimit(cfg); got != 15 {
|
|
t.Fatalf("expected 15, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestDeduplicateAndSort(t *testing.T) {
|
|
t.Parallel()
|
|
items := []adapters.MemoryItem{
|
|
{ID: "a", Score: 1.0, Memory: "first"},
|
|
{ID: "b", Score: 3.0, Memory: "second"},
|
|
{ID: "a", Score: 2.0, Memory: "duplicate"},
|
|
{ID: "c", Score: 2.5, Memory: "third"},
|
|
}
|
|
result := deduplicateAndSort(items)
|
|
if len(result) != 3 {
|
|
t.Fatalf("expected 3 items after dedup, got %d", len(result))
|
|
}
|
|
if result[0].ID != "b" {
|
|
t.Fatalf("expected highest score first, got %q", result[0].ID)
|
|
}
|
|
}
|