mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
feat(agent): include platform self identities in system prompts
Give bots their known per-channel account handles in the system prompt so they can reason about platform-specific self references consistently. Reuse persisted channel self_identity data across chat, discuss, schedule, heartbeat, and subagent prompts.
This commit is contained in:
+2
-1
@@ -257,11 +257,12 @@ func injectToolProviders(a *agentpkg.Agent, msgService *message.DBService, provi
|
||||
}
|
||||
}
|
||||
|
||||
func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, routeService *route.DBService, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver {
|
||||
func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, channelStore *channel.Store, routeService *route.DBService, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver {
|
||||
resolver := flow.NewResolver(log, modelsService, queries, chatService, msgService, settingsService, accountService, a, rc.TimezoneLocation, 120*time.Second)
|
||||
resolver.SetMemoryRegistry(memoryRegistry)
|
||||
resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler})
|
||||
resolver.SetGatewayAssetLoader(&gatewayAssetLoaderAdapter{media: mediaService})
|
||||
resolver.SetChannelStore(channelStore)
|
||||
resolver.SetRouteService(routeService)
|
||||
resolver.SetSessionService(sessionService)
|
||||
resolver.SetEventPublisher(eventHub)
|
||||
|
||||
+15
-12
@@ -47,6 +47,7 @@ func init() {
|
||||
"_memory": mustReadPrompt("prompts/_memory.md"),
|
||||
"_tools": mustReadPrompt("prompts/_tools.md"),
|
||||
"_contacts": mustReadPrompt("prompts/_contacts.md"),
|
||||
"_identities": mustReadPrompt("prompts/_identities.md"),
|
||||
"_schedule_task": mustReadPrompt("prompts/_schedule_task.md"),
|
||||
"_subagent": mustReadPrompt("prompts/_subagent.md"),
|
||||
}
|
||||
@@ -144,23 +145,25 @@ func GenerateSystemPrompt(params SystemPromptParams) string {
|
||||
tmpl := selectSystemTemplate(params.SessionType)
|
||||
|
||||
return render(tmpl, map[string]string{
|
||||
"home": home,
|
||||
"currentTime": now.Format(time.RFC3339),
|
||||
"timezone": timezoneName,
|
||||
"basicTools": strings.Join(basicTools, "\n"),
|
||||
"skillsSection": skillsSection,
|
||||
"fileSections": fileSections,
|
||||
"home": home,
|
||||
"currentTime": now.Format(time.RFC3339),
|
||||
"timezone": timezoneName,
|
||||
"basicTools": strings.Join(basicTools, "\n"),
|
||||
"skillsSection": skillsSection,
|
||||
"platformIdentitiesSection": strings.TrimSpace(params.PlatformIdentitiesSection),
|
||||
"fileSections": fileSections,
|
||||
})
|
||||
}
|
||||
|
||||
// SystemPromptParams holds all inputs for system prompt generation.
|
||||
type SystemPromptParams struct {
|
||||
SessionType string
|
||||
Skills []SkillEntry
|
||||
Files []SystemFile
|
||||
Now time.Time
|
||||
Timezone string
|
||||
SupportsImageInput bool
|
||||
SessionType string
|
||||
Skills []SkillEntry
|
||||
Files []SystemFile
|
||||
Now time.Time
|
||||
Timezone string
|
||||
SupportsImageInput bool
|
||||
PlatformIdentitiesSection string
|
||||
}
|
||||
|
||||
// GenerateSchedulePrompt builds the user message for a scheduled task trigger.
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateSystemPromptIncludesPlatformIdentitiesInChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prompt := GenerateSystemPrompt(SystemPromptParams{
|
||||
SessionType: "chat",
|
||||
Now: time.Unix(1, 0).UTC(),
|
||||
Timezone: "UTC",
|
||||
PlatformIdentitiesSection: "## Platform Identities\n\n<identity channel=\"telegram\" username=\"@memoh\"/>",
|
||||
})
|
||||
|
||||
if !strings.Contains(prompt, "## Platform Identities") {
|
||||
t.Fatalf("expected platform identities heading in prompt")
|
||||
}
|
||||
if !strings.Contains(prompt, `<identity channel="telegram" username="@memoh"/>`) {
|
||||
t.Fatalf("expected platform identity XML in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSystemPromptIncludesPlatformIdentitiesInDiscuss(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prompt := GenerateSystemPrompt(SystemPromptParams{
|
||||
SessionType: "discuss",
|
||||
Now: time.Unix(1, 0).UTC(),
|
||||
Timezone: "UTC",
|
||||
PlatformIdentitiesSection: "## Platform Identities\n\n<identity channel=\"discord\" username=\"@memoh\"/>",
|
||||
})
|
||||
|
||||
if !strings.Contains(prompt, "## Platform Identities") {
|
||||
t.Fatalf("expected platform identities heading in discuss prompt")
|
||||
}
|
||||
if !strings.Contains(prompt, `<identity channel="discord" username="@memoh"/>`) {
|
||||
t.Fatalf("expected platform identity XML in discuss prompt")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
{{platformIdentitiesSection}}
|
||||
@@ -43,6 +43,8 @@ You are in **chat mode** — your text output IS your reply. Whatever you write
|
||||
|
||||
{{include:_contacts}}
|
||||
|
||||
{{include:_identities}}
|
||||
|
||||
## Message Format
|
||||
|
||||
User messages are wrapped in `<message>` XML tags with metadata attributes:
|
||||
|
||||
@@ -60,6 +60,8 @@ Not every message needs a response. Staying silent is valid and often appropriat
|
||||
|
||||
{{include:_contacts}}
|
||||
|
||||
{{include:_identities}}
|
||||
|
||||
## Message Format
|
||||
|
||||
Chat history appears as XML in your conversation. Each message looks like:
|
||||
|
||||
@@ -20,6 +20,8 @@ You are in **heartbeat mode** — a periodic system-triggered check. There is no
|
||||
|
||||
{{include:_contacts}}
|
||||
|
||||
{{include:_identities}}
|
||||
|
||||
## The HEARTBEAT_OK Contract
|
||||
|
||||
- If nothing needs attention, reply with exactly `HEARTBEAT_OK`.
|
||||
|
||||
@@ -20,6 +20,8 @@ You are in **schedule mode** — executing a scheduled task. There is no active
|
||||
|
||||
{{include:_contacts}}
|
||||
|
||||
{{include:_identities}}
|
||||
|
||||
## How to Deliver Results
|
||||
|
||||
Use `send` to deliver results to the intended channel — there is no active conversation to reply to. Use `get_contacts` to find the right target.
|
||||
|
||||
@@ -18,3 +18,5 @@ You have access to:
|
||||
- Do NOT create schedules or manage memory
|
||||
- Keep private data private
|
||||
- Don't run destructive commands without necessity
|
||||
|
||||
{{include:_identities}}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -237,6 +238,32 @@ func (s *Store) ResolveEffectiveConfig(ctx context.Context, botID string, channe
|
||||
return ChannelConfig{}, fmt.Errorf("%w", ErrChannelConfigNotFound)
|
||||
}
|
||||
|
||||
// ListBotConfigs returns all registered channel configs for a bot.
|
||||
// Missing configs are skipped so callers can enumerate platform state without
|
||||
// knowing which integrations are currently configured.
|
||||
func (s *Store) ListBotConfigs(ctx context.Context, botID string) ([]ChannelConfig, error) {
|
||||
if strings.TrimSpace(botID) == "" {
|
||||
return nil, errors.New("bot id is required")
|
||||
}
|
||||
types := s.registry.Types()
|
||||
sort.Slice(types, func(i, j int) bool {
|
||||
return strings.Compare(types[i].String(), types[j].String()) < 0
|
||||
})
|
||||
|
||||
items := make([]ChannelConfig, 0, len(types))
|
||||
for _, channelType := range types {
|
||||
cfg, err := s.ResolveEffectiveConfig(ctx, botID, channelType)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrChannelConfigNotFound) {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, cfg)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// ListConfigsByType returns all channel configurations of the given type.
|
||||
func (s *Store) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) {
|
||||
if s.queries == nil {
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
const platformIdentitiesIntro = "## Platform Identities\n\nThese XML tags describe your own known account identities across connected platforms.\n"
|
||||
|
||||
type identityAttr struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
func buildPlatformIdentitiesSection(configs []channel.ChannelConfig) string {
|
||||
xmlBlock := buildPlatformIdentitiesXML(configs)
|
||||
if xmlBlock == "" {
|
||||
return ""
|
||||
}
|
||||
return platformIdentitiesIntro + "\n" + xmlBlock
|
||||
}
|
||||
|
||||
func buildPlatformIdentitiesXML(configs []channel.ChannelConfig) string {
|
||||
if len(configs) == 0 {
|
||||
return ""
|
||||
}
|
||||
sorted := make([]channel.ChannelConfig, len(configs))
|
||||
copy(sorted, configs)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
left := sorted[i]
|
||||
right := sorted[j]
|
||||
if cmp := strings.Compare(left.ChannelType.String(), right.ChannelType.String()); cmp != 0 {
|
||||
return cmp < 0
|
||||
}
|
||||
if cmp := strings.Compare(left.ExternalIdentity, right.ExternalIdentity); cmp != 0 {
|
||||
return cmp < 0
|
||||
}
|
||||
return strings.Compare(left.ID, right.ID) < 0
|
||||
})
|
||||
|
||||
lines := make([]string, 0, len(sorted))
|
||||
for _, cfg := range sorted {
|
||||
line := buildPlatformIdentityLine(cfg)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func buildPlatformIdentityLine(cfg channel.ChannelConfig) string {
|
||||
channelName := strings.TrimSpace(cfg.ChannelType.String())
|
||||
if channelName == "" {
|
||||
return ""
|
||||
}
|
||||
attrs := []identityAttr{{
|
||||
Name: "channel",
|
||||
Value: channelName,
|
||||
}}
|
||||
|
||||
keys := make([]string, 0, len(cfg.SelfIdentity))
|
||||
for key := range cfg.SelfIdentity {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
seen := map[string]struct{}{
|
||||
"channel": {},
|
||||
}
|
||||
for _, key := range keys {
|
||||
name, ok := normalizeIdentityAttrName(key)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[name]; exists {
|
||||
continue
|
||||
}
|
||||
value, ok := stringifyIdentityAttrValue(cfg.SelfIdentity[key])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if name == "username" {
|
||||
value = normalizeIdentityUsername(value)
|
||||
if strings.TrimSpace(value) == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
attrs = append(attrs, identityAttr{Name: name, Value: value})
|
||||
seen[name] = struct{}{}
|
||||
}
|
||||
|
||||
if externalIdentity := strings.TrimSpace(cfg.ExternalIdentity); externalIdentity != "" {
|
||||
if _, exists := seen["external_identity"]; !exists {
|
||||
attrs = append(attrs, identityAttr{Name: "external_identity", Value: externalIdentity})
|
||||
}
|
||||
}
|
||||
|
||||
if len(attrs) == 1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<identity")
|
||||
for _, attr := range attrs {
|
||||
sb.WriteByte(' ')
|
||||
sb.WriteString(attr.Name)
|
||||
sb.WriteString(`="`)
|
||||
sb.WriteString(escapeIdentityAttrValue(attr.Value))
|
||||
sb.WriteByte('"')
|
||||
}
|
||||
sb.WriteString("/>")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func normalizeIdentityAttrName(name string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case unicode.IsLetter(r), unicode.IsDigit(r), r == '_', r == '-', r == '.':
|
||||
sb.WriteRune(r)
|
||||
default:
|
||||
sb.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
normalized := strings.Trim(sb.String(), "_-.")
|
||||
if normalized == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
first, _ := utf8.DecodeRuneInString(normalized)
|
||||
if !unicode.IsLetter(first) && first != '_' {
|
||||
normalized = "attr_" + normalized
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(normalized), "xml") {
|
||||
normalized = "attr_" + normalized
|
||||
}
|
||||
return normalized, true
|
||||
}
|
||||
|
||||
func stringifyIdentityAttrValue(value any) (string, bool) {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return "", false
|
||||
case string:
|
||||
s := strings.TrimSpace(v)
|
||||
return s, s != ""
|
||||
case json.Number:
|
||||
s := strings.TrimSpace(v.String())
|
||||
return s, s != ""
|
||||
case fmt.Stringer:
|
||||
s := strings.TrimSpace(v.String())
|
||||
return s, s != ""
|
||||
case bool:
|
||||
return strconv.FormatBool(v), true
|
||||
case int, int8, int16, int32, int64,
|
||||
uint, uint8, uint16, uint32, uint64,
|
||||
float32, float64:
|
||||
return fmt.Sprint(v), true
|
||||
default:
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
s := strings.TrimSpace(fmt.Sprint(v))
|
||||
return s, s != ""
|
||||
}
|
||||
s := strings.TrimSpace(string(data))
|
||||
if s == "" || s == "null" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeIdentityUsername(username string) string {
|
||||
trimmed := strings.TrimSpace(username)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "@") {
|
||||
return trimmed
|
||||
}
|
||||
return "@" + trimmed
|
||||
}
|
||||
|
||||
func escapeIdentityAttrValue(value string) string {
|
||||
replacer := strings.NewReplacer(
|
||||
"&", "&",
|
||||
`"`, """,
|
||||
"<", "<",
|
||||
">", ">",
|
||||
"'", "'",
|
||||
)
|
||||
return replacer.Replace(value)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
)
|
||||
|
||||
func TestBuildPlatformIdentitiesXML(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configs := []channel.ChannelConfig{
|
||||
{
|
||||
ID: "tg-1",
|
||||
ChannelType: channel.ChannelTypeTelegram,
|
||||
ExternalIdentity: "12345",
|
||||
SelfIdentity: map[string]any{
|
||||
"user_id": "12345",
|
||||
"username": "memoh_bot",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "discord-1",
|
||||
ChannelType: channel.ChannelTypeDiscord,
|
||||
ExternalIdentity: "98765",
|
||||
SelfIdentity: map[string]any{
|
||||
"name": "Memoh & Co",
|
||||
"username": "@memoh",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := buildPlatformIdentitiesXML(configs)
|
||||
want := strings.Join([]string{
|
||||
`<identity channel="discord" name="Memoh & Co" username="@memoh" external_identity="98765"/>`,
|
||||
`<identity channel="telegram" user_id="12345" username="@memoh_bot" external_identity="12345"/>`,
|
||||
}, "\n")
|
||||
if got != want {
|
||||
t.Fatalf("unexpected XML:\nwant:\n%s\n\ngot:\n%s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPlatformIdentityLineNormalizesAttrs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := buildPlatformIdentityLine(channel.ChannelConfig{
|
||||
ChannelType: channel.ChannelTypeTelegram,
|
||||
SelfIdentity: map[string]any{
|
||||
"123id": 7,
|
||||
"display name": `Memoh <Bot>`,
|
||||
"username": "memoh",
|
||||
"xml_name": "reserved",
|
||||
},
|
||||
})
|
||||
|
||||
want := `<identity channel="telegram" attr_123id="7" display_name="Memoh <Bot>" username="@memoh" attr_xml_name="reserved"/>`
|
||||
if got != want {
|
||||
t.Fatalf("unexpected identity line:\nwant: %s\ngot: %s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPlatformIdentitiesSectionSkipsEmptyConfigs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := buildPlatformIdentitiesSection([]channel.ChannelConfig{{
|
||||
ID: "local-1",
|
||||
ChannelType: channel.ChannelTypeLocal,
|
||||
}})
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty section, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
agentpkg "github.com/memohai/memoh/internal/agent"
|
||||
"github.com/memohai/memoh/internal/agent/background"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/compaction"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
@@ -63,6 +64,10 @@ type gatewayAssetLoader interface {
|
||||
OpenForGateway(ctx context.Context, botID, contentHash string) (reader io.ReadCloser, mime string, err error)
|
||||
}
|
||||
|
||||
type botChannelConfigReader interface {
|
||||
ListBotConfigs(ctx context.Context, botID string) ([]channel.ChannelConfig, error)
|
||||
}
|
||||
|
||||
// Resolver orchestrates chat with the internal agent.
|
||||
type Resolver struct {
|
||||
agent *agentpkg.Agent
|
||||
@@ -79,6 +84,7 @@ type Resolver struct {
|
||||
eventPublisher messageevent.Publisher
|
||||
skillLoader SkillLoader
|
||||
assetLoader gatewayAssetLoader
|
||||
channelStore botChannelConfigReader
|
||||
pipeline *pipelinepkg.Pipeline
|
||||
streamHTTPClient *http.Client
|
||||
bgManager *background.Manager
|
||||
@@ -161,6 +167,12 @@ func (r *Resolver) SetGatewayAssetLoader(loader gatewayAssetLoader) {
|
||||
r.assetLoader = loader
|
||||
}
|
||||
|
||||
// SetChannelStore configures the bot channel config store used to load
|
||||
// platform identity metadata for system prompt generation.
|
||||
func (r *Resolver) SetChannelStore(store botChannelConfigReader) {
|
||||
r.channelStore = store
|
||||
}
|
||||
|
||||
// SetCompactionService configures the compaction service for context compaction.
|
||||
func (r *Resolver) SetCompactionService(s *compaction.Service) {
|
||||
r.compactionService = s
|
||||
@@ -642,13 +654,26 @@ func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig)
|
||||
if cfg.Identity.TimezoneLocation != nil {
|
||||
now = now.In(cfg.Identity.TimezoneLocation)
|
||||
}
|
||||
platformIdentitiesSection := ""
|
||||
if r.channelStore != nil {
|
||||
channelConfigs, err := r.channelStore.ListBotConfigs(ctx, cfg.Identity.BotID)
|
||||
if err != nil {
|
||||
r.logger.Warn("load bot platform identities failed",
|
||||
slog.String("bot_id", cfg.Identity.BotID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
} else {
|
||||
platformIdentitiesSection = buildPlatformIdentitiesSection(channelConfigs)
|
||||
}
|
||||
}
|
||||
cfg.System = agentpkg.GenerateSystemPrompt(agentpkg.SystemPromptParams{
|
||||
SessionType: cfg.SessionType,
|
||||
Skills: cfg.Skills,
|
||||
Files: files,
|
||||
Now: now,
|
||||
Timezone: cfg.Identity.Timezone,
|
||||
SupportsImageInput: supportsImageInput,
|
||||
SessionType: cfg.SessionType,
|
||||
Skills: cfg.Skills,
|
||||
Files: files,
|
||||
Now: now,
|
||||
Timezone: cfg.Identity.Timezone,
|
||||
SupportsImageInput: supportsImageInput,
|
||||
PlatformIdentitiesSection: platformIdentitiesSection,
|
||||
})
|
||||
|
||||
if cfg.Query != "" {
|
||||
|
||||
Reference in New Issue
Block a user