From 9f10033f63b9a87f5ff956fc64be492e32d88ec3 Mon Sep 17 00:00:00 2001 From: Acbox Date: Thu, 16 Apr 2026 16:44:29 +0800 Subject: [PATCH] feat(agent): include platform self identities in system prompts Give bots their known per-channel account handles in the system prompt so they can reason about platform-specific self references consistently. Reuse persisted channel self_identity data across chat, discuss, schedule, heartbeat, and subagent prompts. --- cmd/agent/app.go | 3 +- internal/agent/prompt.go | 27 ++- internal/agent/prompt_test.go | 43 ++++ internal/agent/prompts/_identities.md | 1 + internal/agent/prompts/system_chat.md | 2 + internal/agent/prompts/system_discuss.md | 2 + internal/agent/prompts/system_heartbeat.md | 2 + internal/agent/prompts/system_schedule.md | 2 + internal/agent/prompts/system_subagent.md | 2 + internal/channel/service.go | 27 +++ .../conversation/flow/platform_identity.go | 207 ++++++++++++++++++ .../flow/platform_identity_test.go | 73 ++++++ internal/conversation/flow/resolver.go | 37 +++- 13 files changed, 409 insertions(+), 19 deletions(-) create mode 100644 internal/agent/prompt_test.go create mode 100644 internal/agent/prompts/_identities.md create mode 100644 internal/conversation/flow/platform_identity.go create mode 100644 internal/conversation/flow/platform_identity_test.go diff --git a/cmd/agent/app.go b/cmd/agent/app.go index 4330ba90..315263df 100644 --- a/cmd/agent/app.go +++ b/cmd/agent/app.go @@ -257,11 +257,12 @@ func injectToolProviders(a *agentpkg.Agent, msgService *message.DBService, provi } } -func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, routeService *route.DBService, sessionService 
*sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver { +func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, channelStore *channel.Store, routeService *route.DBService, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver { resolver := flow.NewResolver(log, modelsService, queries, chatService, msgService, settingsService, accountService, a, rc.TimezoneLocation, 120*time.Second) resolver.SetMemoryRegistry(memoryRegistry) resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler}) resolver.SetGatewayAssetLoader(&gatewayAssetLoaderAdapter{media: mediaService}) + resolver.SetChannelStore(channelStore) resolver.SetRouteService(routeService) resolver.SetSessionService(sessionService) resolver.SetEventPublisher(eventHub) diff --git a/internal/agent/prompt.go b/internal/agent/prompt.go index 60c8b422..01e67ebc 100644 --- a/internal/agent/prompt.go +++ b/internal/agent/prompt.go @@ -47,6 +47,7 @@ func init() { "_memory": mustReadPrompt("prompts/_memory.md"), "_tools": mustReadPrompt("prompts/_tools.md"), "_contacts": mustReadPrompt("prompts/_contacts.md"), + "_identities": mustReadPrompt("prompts/_identities.md"), "_schedule_task": mustReadPrompt("prompts/_schedule_task.md"), "_subagent": mustReadPrompt("prompts/_subagent.md"), } @@ -144,23 +145,25 @@ func GenerateSystemPrompt(params SystemPromptParams) string { tmpl := selectSystemTemplate(params.SessionType) return render(tmpl, 
map[string]string{ - "home": home, - "currentTime": now.Format(time.RFC3339), - "timezone": timezoneName, - "basicTools": strings.Join(basicTools, "\n"), - "skillsSection": skillsSection, - "fileSections": fileSections, + "home": home, + "currentTime": now.Format(time.RFC3339), + "timezone": timezoneName, + "basicTools": strings.Join(basicTools, "\n"), + "skillsSection": skillsSection, + "platformIdentitiesSection": strings.TrimSpace(params.PlatformIdentitiesSection), + "fileSections": fileSections, }) } // SystemPromptParams holds all inputs for system prompt generation. type SystemPromptParams struct { - SessionType string - Skills []SkillEntry - Files []SystemFile - Now time.Time - Timezone string - SupportsImageInput bool + SessionType string + Skills []SkillEntry + Files []SystemFile + Now time.Time + Timezone string + SupportsImageInput bool + PlatformIdentitiesSection string } // GenerateSchedulePrompt builds the user message for a scheduled task trigger. diff --git a/internal/agent/prompt_test.go b/internal/agent/prompt_test.go new file mode 100644 index 00000000..50cd3301 --- /dev/null +++ b/internal/agent/prompt_test.go @@ -0,0 +1,43 @@ +package agent + +import ( + "strings" + "testing" + "time" +) + +func TestGenerateSystemPromptIncludesPlatformIdentitiesInChat(t *testing.T) { + t.Parallel() + + prompt := GenerateSystemPrompt(SystemPromptParams{ + SessionType: "chat", + Now: time.Unix(1, 0).UTC(), + Timezone: "UTC", + PlatformIdentitiesSection: "## Platform Identities\n\n", + }) + + if !strings.Contains(prompt, "## Platform Identities") { + t.Fatalf("expected platform identities heading in prompt") + } + if !strings.Contains(prompt, ``) { + t.Fatalf("expected platform identity XML in prompt") + } +} + +func TestGenerateSystemPromptIncludesPlatformIdentitiesInDiscuss(t *testing.T) { + t.Parallel() + + prompt := GenerateSystemPrompt(SystemPromptParams{ + SessionType: "discuss", + Now: time.Unix(1, 0).UTC(), + Timezone: "UTC", + PlatformIdentitiesSection: 
"## Platform Identities\n\n", + }) + + if !strings.Contains(prompt, "## Platform Identities") { + t.Fatalf("expected platform identities heading in discuss prompt") + } + if !strings.Contains(prompt, ``) { + t.Fatalf("expected platform identity XML in discuss prompt") + } +} diff --git a/internal/agent/prompts/_identities.md b/internal/agent/prompts/_identities.md new file mode 100644 index 00000000..0f84de12 --- /dev/null +++ b/internal/agent/prompts/_identities.md @@ -0,0 +1 @@ +{{platformIdentitiesSection}} diff --git a/internal/agent/prompts/system_chat.md b/internal/agent/prompts/system_chat.md index 4e12bc20..0626e927 100644 --- a/internal/agent/prompts/system_chat.md +++ b/internal/agent/prompts/system_chat.md @@ -43,6 +43,8 @@ You are in **chat mode** — your text output IS your reply. Whatever you write {{include:_contacts}} +{{include:_identities}} + ## Message Format User messages are wrapped in `` XML tags with metadata attributes: diff --git a/internal/agent/prompts/system_discuss.md b/internal/agent/prompts/system_discuss.md index f4263ac1..dfa170cc 100644 --- a/internal/agent/prompts/system_discuss.md +++ b/internal/agent/prompts/system_discuss.md @@ -60,6 +60,8 @@ Not every message needs a response. Staying silent is valid and often appropriat {{include:_contacts}} +{{include:_identities}} + ## Message Format Chat history appears as XML in your conversation. Each message looks like: diff --git a/internal/agent/prompts/system_heartbeat.md b/internal/agent/prompts/system_heartbeat.md index ed2a15cf..d59a9adf 100644 --- a/internal/agent/prompts/system_heartbeat.md +++ b/internal/agent/prompts/system_heartbeat.md @@ -20,6 +20,8 @@ You are in **heartbeat mode** — a periodic system-triggered check. There is no {{include:_contacts}} +{{include:_identities}} + ## The HEARTBEAT_OK Contract - If nothing needs attention, reply with exactly `HEARTBEAT_OK`. 
diff --git a/internal/agent/prompts/system_schedule.md b/internal/agent/prompts/system_schedule.md index d9c4d5b6..52bc040b 100644 --- a/internal/agent/prompts/system_schedule.md +++ b/internal/agent/prompts/system_schedule.md @@ -20,6 +20,8 @@ You are in **schedule mode** — executing a scheduled task. There is no active {{include:_contacts}} +{{include:_identities}} + ## How to Deliver Results Use `send` to deliver results to the intended channel — there is no active conversation to reply to. Use `get_contacts` to find the right target. diff --git a/internal/agent/prompts/system_subagent.md b/internal/agent/prompts/system_subagent.md index edd11fab..bba109d7 100644 --- a/internal/agent/prompts/system_subagent.md +++ b/internal/agent/prompts/system_subagent.md @@ -18,3 +18,5 @@ You have access to: - Do NOT create schedules or manage memory - Keep private data private - Don't run destructive commands without necessity + +{{include:_identities}} diff --git a/internal/channel/service.go b/internal/channel/service.go index 1ea8248f..da6328a0 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "sort" "strings" "time" @@ -237,6 +238,32 @@ func (s *Store) ResolveEffectiveConfig(ctx context.Context, botID string, channe return ChannelConfig{}, fmt.Errorf("%w", ErrChannelConfigNotFound) } +// ListBotConfigs returns all registered channel configs for a bot. +// Missing configs are skipped so callers can enumerate platform state without +// knowing which integrations are currently configured. 
+func (s *Store) ListBotConfigs(ctx context.Context, botID string) ([]ChannelConfig, error) { + if strings.TrimSpace(botID) == "" { + return nil, errors.New("bot id is required") + } + types := s.registry.Types() + sort.Slice(types, func(i, j int) bool { + return strings.Compare(types[i].String(), types[j].String()) < 0 + }) + + items := make([]ChannelConfig, 0, len(types)) + for _, channelType := range types { + cfg, err := s.ResolveEffectiveConfig(ctx, botID, channelType) + if err != nil { + if errors.Is(err, ErrChannelConfigNotFound) { + continue + } + return nil, err + } + items = append(items, cfg) + } + return items, nil +} + // ListConfigsByType returns all channel configurations of the given type. func (s *Store) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { if s.queries == nil { diff --git a/internal/conversation/flow/platform_identity.go b/internal/conversation/flow/platform_identity.go new file mode 100644 index 00000000..d9c11d7d --- /dev/null +++ b/internal/conversation/flow/platform_identity.go @@ -0,0 +1,207 @@ +package flow + +import ( + "encoding/json" + "fmt" + "sort" + "strconv" + "strings" + "unicode" + "unicode/utf8" + + "github.com/memohai/memoh/internal/channel" +) + +const platformIdentitiesIntro = "## Platform Identities\n\nThese XML tags describe your own known account identities across connected platforms.\n" + +type identityAttr struct { + Name string + Value string +} + +func buildPlatformIdentitiesSection(configs []channel.ChannelConfig) string { + xmlBlock := buildPlatformIdentitiesXML(configs) + if xmlBlock == "" { + return "" + } + return platformIdentitiesIntro + "\n" + xmlBlock +} + +func buildPlatformIdentitiesXML(configs []channel.ChannelConfig) string { + if len(configs) == 0 { + return "" + } + sorted := make([]channel.ChannelConfig, len(configs)) + copy(sorted, configs) + sort.Slice(sorted, func(i, j int) bool { + left := sorted[i] + right := sorted[j] + if cmp := 
strings.Compare(left.ChannelType.String(), right.ChannelType.String()); cmp != 0 { + return cmp < 0 + } + if cmp := strings.Compare(left.ExternalIdentity, right.ExternalIdentity); cmp != 0 { + return cmp < 0 + } + return strings.Compare(left.ID, right.ID) < 0 + }) + + lines := make([]string, 0, len(sorted)) + for _, cfg := range sorted { + line := buildPlatformIdentityLine(cfg) + if line == "" { + continue + } + lines = append(lines, line) + } + return strings.Join(lines, "\n") +} + +func buildPlatformIdentityLine(cfg channel.ChannelConfig) string { + channelName := strings.TrimSpace(cfg.ChannelType.String()) + if channelName == "" { + return "" + } + attrs := []identityAttr{{ + Name: "channel", + Value: channelName, + }} + + keys := make([]string, 0, len(cfg.SelfIdentity)) + for key := range cfg.SelfIdentity { + keys = append(keys, key) + } + sort.Strings(keys) + + seen := map[string]struct{}{ + "channel": {}, + } + for _, key := range keys { + name, ok := normalizeIdentityAttrName(key) + if !ok { + continue + } + if _, exists := seen[name]; exists { + continue + } + value, ok := stringifyIdentityAttrValue(cfg.SelfIdentity[key]) + if !ok { + continue + } + if name == "username" { + value = normalizeIdentityUsername(value) + if strings.TrimSpace(value) == "" { + continue + } + } + attrs = append(attrs, identityAttr{Name: name, Value: value}) + seen[name] = struct{}{} + } + + if externalIdentity := strings.TrimSpace(cfg.ExternalIdentity); externalIdentity != "" { + if _, exists := seen["external_identity"]; !exists { + attrs = append(attrs, identityAttr{Name: "external_identity", Value: externalIdentity}) + } + } + + if len(attrs) == 1 { + return "" + } + + var sb strings.Builder + sb.WriteString("") + return sb.String() +} + +func normalizeIdentityAttrName(name string) (string, bool) { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return "", false + } + + var sb strings.Builder + for _, r := range trimmed { + switch { + case unicode.IsLetter(r), 
unicode.IsDigit(r), r == '_', r == '-', r == '.': + sb.WriteRune(r) + default: + sb.WriteByte('_') + } + } + + normalized := strings.Trim(sb.String(), "_-.") + if normalized == "" { + return "", false + } + + first, _ := utf8.DecodeRuneInString(normalized) + if !unicode.IsLetter(first) && first != '_' { + normalized = "attr_" + normalized + } + if strings.HasPrefix(strings.ToLower(normalized), "xml") { + normalized = "attr_" + normalized + } + return normalized, true +} + +func stringifyIdentityAttrValue(value any) (string, bool) { + switch v := value.(type) { + case nil: + return "", false + case string: + s := strings.TrimSpace(v) + return s, s != "" + case json.Number: + s := strings.TrimSpace(v.String()) + return s, s != "" + case fmt.Stringer: + s := strings.TrimSpace(v.String()) + return s, s != "" + case bool: + return strconv.FormatBool(v), true + case int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64: + return fmt.Sprint(v), true + default: + data, err := json.Marshal(v) + if err != nil { + s := strings.TrimSpace(fmt.Sprint(v)) + return s, s != "" + } + s := strings.TrimSpace(string(data)) + if s == "" || s == "null" { + return "", false + } + return s, true + } +} + +func normalizeIdentityUsername(username string) string { + trimmed := strings.TrimSpace(username) + if trimmed == "" { + return "" + } + if strings.HasPrefix(trimmed, "@") { + return trimmed + } + return "@" + trimmed +} + +func escapeIdentityAttrValue(value string) string { + replacer := strings.NewReplacer( + "&", "&amp;", + `"`, "&quot;", + "<", "&lt;", + ">", "&gt;", + "'", "&apos;", + ) + return replacer.Replace(value) +} diff --git a/internal/conversation/flow/platform_identity_test.go b/internal/conversation/flow/platform_identity_test.go new file mode 100644 index 00000000..7e94a535 --- /dev/null +++ b/internal/conversation/flow/platform_identity_test.go @@ -0,0 +1,73 @@ +package flow + +import ( + "strings" + "testing" + + 
"github.com/memohai/memoh/internal/channel" +) + +func TestBuildPlatformIdentitiesXML(t *testing.T) { + t.Parallel() + + configs := []channel.ChannelConfig{ + { + ID: "tg-1", + ChannelType: channel.ChannelTypeTelegram, + ExternalIdentity: "12345", + SelfIdentity: map[string]any{ + "user_id": "12345", + "username": "memoh_bot", + }, + }, + { + ID: "discord-1", + ChannelType: channel.ChannelTypeDiscord, + ExternalIdentity: "98765", + SelfIdentity: map[string]any{ + "name": "Memoh & Co", + "username": "@memoh", + }, + }, + } + + got := buildPlatformIdentitiesXML(configs) + want := strings.Join([]string{ + ``, + ``, + }, "\n") + if got != want { + t.Fatalf("unexpected XML:\nwant:\n%s\n\ngot:\n%s", want, got) + } +} + +func TestBuildPlatformIdentityLineNormalizesAttrs(t *testing.T) { + t.Parallel() + + got := buildPlatformIdentityLine(channel.ChannelConfig{ + ChannelType: channel.ChannelTypeTelegram, + SelfIdentity: map[string]any{ + "123id": 7, + "display name": `Memoh `, + "username": "memoh", + "xml_name": "reserved", + }, + }) + + want := `` + if got != want { + t.Fatalf("unexpected identity line:\nwant: %s\ngot: %s", want, got) + } +} + +func TestBuildPlatformIdentitiesSectionSkipsEmptyConfigs(t *testing.T) { + t.Parallel() + + got := buildPlatformIdentitiesSection([]channel.ChannelConfig{{ + ID: "local-1", + ChannelType: channel.ChannelTypeLocal, + }}) + if got != "" { + t.Fatalf("expected empty section, got %q", got) + } +} diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 2c919ba2..f8db30bd 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -22,6 +22,7 @@ import ( "github.com/memohai/memoh/internal/accounts" agentpkg "github.com/memohai/memoh/internal/agent" "github.com/memohai/memoh/internal/agent/background" + "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/compaction" "github.com/memohai/memoh/internal/conversation" 
"github.com/memohai/memoh/internal/db/sqlc" @@ -63,6 +64,10 @@ type gatewayAssetLoader interface { OpenForGateway(ctx context.Context, botID, contentHash string) (reader io.ReadCloser, mime string, err error) } +type botChannelConfigReader interface { + ListBotConfigs(ctx context.Context, botID string) ([]channel.ChannelConfig, error) +} + // Resolver orchestrates chat with the internal agent. type Resolver struct { agent *agentpkg.Agent @@ -79,6 +84,7 @@ type Resolver struct { eventPublisher messageevent.Publisher skillLoader SkillLoader assetLoader gatewayAssetLoader + channelStore botChannelConfigReader pipeline *pipelinepkg.Pipeline streamHTTPClient *http.Client bgManager *background.Manager @@ -161,6 +167,12 @@ func (r *Resolver) SetGatewayAssetLoader(loader gatewayAssetLoader) { r.assetLoader = loader } +// SetChannelStore configures the bot channel config store used to load +// platform identity metadata for system prompt generation. +func (r *Resolver) SetChannelStore(store botChannelConfigReader) { + r.channelStore = store +} + // SetCompactionService configures the compaction service for context compaction. 
func (r *Resolver) SetCompactionService(s *compaction.Service) { r.compactionService = s @@ -642,13 +654,26 @@ func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig) if cfg.Identity.TimezoneLocation != nil { now = now.In(cfg.Identity.TimezoneLocation) } + platformIdentitiesSection := "" + if r.channelStore != nil { + channelConfigs, err := r.channelStore.ListBotConfigs(ctx, cfg.Identity.BotID) + if err != nil { + r.logger.Warn("load bot platform identities failed", + slog.String("bot_id", cfg.Identity.BotID), + slog.Any("error", err), + ) + } else { + platformIdentitiesSection = buildPlatformIdentitiesSection(channelConfigs) + } + } cfg.System = agentpkg.GenerateSystemPrompt(agentpkg.SystemPromptParams{ - SessionType: cfg.SessionType, - Skills: cfg.Skills, - Files: files, - Now: now, - Timezone: cfg.Identity.Timezone, - SupportsImageInput: supportsImageInput, + SessionType: cfg.SessionType, + Skills: cfg.Skills, + Files: files, + Now: now, + Timezone: cfg.Identity.Timezone, + SupportsImageInput: supportsImageInput, + PlatformIdentitiesSection: platformIdentitiesSection, }) if cfg.Query != "" {