feat(agent): include platform self identities in system prompts

Give bots their known per-channel account handles in the system prompt so they can reason about platform-specific self references consistently. Reuse persisted channel self_identity data across chat, discuss, schedule, heartbeat, and subagent prompts.
This commit is contained in:
Acbox
2026-04-16 16:44:29 +08:00
parent e0fc2f514e
commit 9f10033f63
13 changed files with 409 additions and 19 deletions
+2 -1
View File
@@ -257,11 +257,12 @@ func injectToolProviders(a *agentpkg.Agent, msgService *message.DBService, provi
}
}
func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, routeService *route.DBService, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver {
func provideChatResolver(log *slog.Logger, a *agentpkg.Agent, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, accountService *accounts.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, memoryRegistry *memprovider.Registry, channelStore *channel.Store, routeService *route.DBService, sessionService *sessionpkg.Service, eventHub *event.Hub, compactionService *compaction.Service, pipeline *pipelinepkg.Pipeline, rc *boot.RuntimeConfig, bgManager *background.Manager) *flow.Resolver {
resolver := flow.NewResolver(log, modelsService, queries, chatService, msgService, settingsService, accountService, a, rc.TimezoneLocation, 120*time.Second)
resolver.SetMemoryRegistry(memoryRegistry)
resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler})
resolver.SetGatewayAssetLoader(&gatewayAssetLoaderAdapter{media: mediaService})
resolver.SetChannelStore(channelStore)
resolver.SetRouteService(routeService)
resolver.SetSessionService(sessionService)
resolver.SetEventPublisher(eventHub)
+15 -12
View File
@@ -47,6 +47,7 @@ func init() {
"_memory": mustReadPrompt("prompts/_memory.md"),
"_tools": mustReadPrompt("prompts/_tools.md"),
"_contacts": mustReadPrompt("prompts/_contacts.md"),
"_identities": mustReadPrompt("prompts/_identities.md"),
"_schedule_task": mustReadPrompt("prompts/_schedule_task.md"),
"_subagent": mustReadPrompt("prompts/_subagent.md"),
}
@@ -144,23 +145,25 @@ func GenerateSystemPrompt(params SystemPromptParams) string {
tmpl := selectSystemTemplate(params.SessionType)
return render(tmpl, map[string]string{
"home": home,
"currentTime": now.Format(time.RFC3339),
"timezone": timezoneName,
"basicTools": strings.Join(basicTools, "\n"),
"skillsSection": skillsSection,
"fileSections": fileSections,
"home": home,
"currentTime": now.Format(time.RFC3339),
"timezone": timezoneName,
"basicTools": strings.Join(basicTools, "\n"),
"skillsSection": skillsSection,
"platformIdentitiesSection": strings.TrimSpace(params.PlatformIdentitiesSection),
"fileSections": fileSections,
})
}
// SystemPromptParams holds all inputs for system prompt generation.
type SystemPromptParams struct {
SessionType string
Skills []SkillEntry
Files []SystemFile
Now time.Time
Timezone string
SupportsImageInput bool
SessionType string
Skills []SkillEntry
Files []SystemFile
Now time.Time
Timezone string
SupportsImageInput bool
PlatformIdentitiesSection string
}
// GenerateSchedulePrompt builds the user message for a scheduled task trigger.
+43
View File
@@ -0,0 +1,43 @@
package agent
import (
"strings"
"testing"
"time"
)
func TestGenerateSystemPromptIncludesPlatformIdentitiesInChat(t *testing.T) {
t.Parallel()
prompt := GenerateSystemPrompt(SystemPromptParams{
SessionType: "chat",
Now: time.Unix(1, 0).UTC(),
Timezone: "UTC",
PlatformIdentitiesSection: "## Platform Identities\n\n<identity channel=\"telegram\" username=\"@memoh\"/>",
})
if !strings.Contains(prompt, "## Platform Identities") {
t.Fatalf("expected platform identities heading in prompt")
}
if !strings.Contains(prompt, `<identity channel="telegram" username="@memoh"/>`) {
t.Fatalf("expected platform identity XML in prompt")
}
}
func TestGenerateSystemPromptIncludesPlatformIdentitiesInDiscuss(t *testing.T) {
t.Parallel()
prompt := GenerateSystemPrompt(SystemPromptParams{
SessionType: "discuss",
Now: time.Unix(1, 0).UTC(),
Timezone: "UTC",
PlatformIdentitiesSection: "## Platform Identities\n\n<identity channel=\"discord\" username=\"@memoh\"/>",
})
if !strings.Contains(prompt, "## Platform Identities") {
t.Fatalf("expected platform identities heading in discuss prompt")
}
if !strings.Contains(prompt, `<identity channel="discord" username="@memoh"/>`) {
t.Fatalf("expected platform identity XML in discuss prompt")
}
}
+1
View File
@@ -0,0 +1 @@
{{platformIdentitiesSection}}
+2
View File
@@ -43,6 +43,8 @@ You are in **chat mode** — your text output IS your reply. Whatever you write
{{include:_contacts}}
{{include:_identities}}
## Message Format
User messages are wrapped in `<message>` XML tags with metadata attributes:
+2
View File
@@ -60,6 +60,8 @@ Not every message needs a response. Staying silent is valid and often appropriat
{{include:_contacts}}
{{include:_identities}}
## Message Format
Chat history appears as XML in your conversation. Each message looks like:
@@ -20,6 +20,8 @@ You are in **heartbeat mode** — a periodic system-triggered check. There is no
{{include:_contacts}}
{{include:_identities}}
## The HEARTBEAT_OK Contract
- If nothing needs attention, reply with exactly `HEARTBEAT_OK`.
@@ -20,6 +20,8 @@ You are in **schedule mode** — executing a scheduled task. There is no active
{{include:_contacts}}
{{include:_identities}}
## How to Deliver Results
Use `send` to deliver results to the intended channel — there is no active conversation to reply to. Use `get_contacts` to find the right target.
@@ -18,3 +18,5 @@ You have access to:
- Do NOT create schedules or manage memory
- Keep private data private
- Don't run destructive commands without necessity
{{include:_identities}}
+27
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sort"
"strings"
"time"
@@ -237,6 +238,32 @@ func (s *Store) ResolveEffectiveConfig(ctx context.Context, botID string, channe
return ChannelConfig{}, fmt.Errorf("%w", ErrChannelConfigNotFound)
}
// ListBotConfigs returns all registered channel configs for a bot.
// Missing configs are skipped so callers can enumerate platform state without
// knowing which integrations are currently configured.
func (s *Store) ListBotConfigs(ctx context.Context, botID string) ([]ChannelConfig, error) {
if strings.TrimSpace(botID) == "" {
return nil, errors.New("bot id is required")
}
types := s.registry.Types()
sort.Slice(types, func(i, j int) bool {
return strings.Compare(types[i].String(), types[j].String()) < 0
})
items := make([]ChannelConfig, 0, len(types))
for _, channelType := range types {
cfg, err := s.ResolveEffectiveConfig(ctx, botID, channelType)
if err != nil {
if errors.Is(err, ErrChannelConfigNotFound) {
continue
}
return nil, err
}
items = append(items, cfg)
}
return items, nil
}
// ListConfigsByType returns all channel configurations of the given type.
func (s *Store) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) {
if s.queries == nil {
@@ -0,0 +1,207 @@
package flow
import (
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/memohai/memoh/internal/channel"
)
const platformIdentitiesIntro = "## Platform Identities\n\nThese XML tags describe your own known account identities across connected platforms.\n"
type identityAttr struct {
Name string
Value string
}
func buildPlatformIdentitiesSection(configs []channel.ChannelConfig) string {
xmlBlock := buildPlatformIdentitiesXML(configs)
if xmlBlock == "" {
return ""
}
return platformIdentitiesIntro + "\n" + xmlBlock
}
func buildPlatformIdentitiesXML(configs []channel.ChannelConfig) string {
if len(configs) == 0 {
return ""
}
sorted := make([]channel.ChannelConfig, len(configs))
copy(sorted, configs)
sort.Slice(sorted, func(i, j int) bool {
left := sorted[i]
right := sorted[j]
if cmp := strings.Compare(left.ChannelType.String(), right.ChannelType.String()); cmp != 0 {
return cmp < 0
}
if cmp := strings.Compare(left.ExternalIdentity, right.ExternalIdentity); cmp != 0 {
return cmp < 0
}
return strings.Compare(left.ID, right.ID) < 0
})
lines := make([]string, 0, len(sorted))
for _, cfg := range sorted {
line := buildPlatformIdentityLine(cfg)
if line == "" {
continue
}
lines = append(lines, line)
}
return strings.Join(lines, "\n")
}
func buildPlatformIdentityLine(cfg channel.ChannelConfig) string {
channelName := strings.TrimSpace(cfg.ChannelType.String())
if channelName == "" {
return ""
}
attrs := []identityAttr{{
Name: "channel",
Value: channelName,
}}
keys := make([]string, 0, len(cfg.SelfIdentity))
for key := range cfg.SelfIdentity {
keys = append(keys, key)
}
sort.Strings(keys)
seen := map[string]struct{}{
"channel": {},
}
for _, key := range keys {
name, ok := normalizeIdentityAttrName(key)
if !ok {
continue
}
if _, exists := seen[name]; exists {
continue
}
value, ok := stringifyIdentityAttrValue(cfg.SelfIdentity[key])
if !ok {
continue
}
if name == "username" {
value = normalizeIdentityUsername(value)
if strings.TrimSpace(value) == "" {
continue
}
}
attrs = append(attrs, identityAttr{Name: name, Value: value})
seen[name] = struct{}{}
}
if externalIdentity := strings.TrimSpace(cfg.ExternalIdentity); externalIdentity != "" {
if _, exists := seen["external_identity"]; !exists {
attrs = append(attrs, identityAttr{Name: "external_identity", Value: externalIdentity})
}
}
if len(attrs) == 1 {
return ""
}
var sb strings.Builder
sb.WriteString("<identity")
for _, attr := range attrs {
sb.WriteByte(' ')
sb.WriteString(attr.Name)
sb.WriteString(`="`)
sb.WriteString(escapeIdentityAttrValue(attr.Value))
sb.WriteByte('"')
}
sb.WriteString("/>")
return sb.String()
}
func normalizeIdentityAttrName(name string) (string, bool) {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return "", false
}
var sb strings.Builder
for _, r := range trimmed {
switch {
case unicode.IsLetter(r), unicode.IsDigit(r), r == '_', r == '-', r == '.':
sb.WriteRune(r)
default:
sb.WriteByte('_')
}
}
normalized := strings.Trim(sb.String(), "_-.")
if normalized == "" {
return "", false
}
first, _ := utf8.DecodeRuneInString(normalized)
if !unicode.IsLetter(first) && first != '_' {
normalized = "attr_" + normalized
}
if strings.HasPrefix(strings.ToLower(normalized), "xml") {
normalized = "attr_" + normalized
}
return normalized, true
}
func stringifyIdentityAttrValue(value any) (string, bool) {
switch v := value.(type) {
case nil:
return "", false
case string:
s := strings.TrimSpace(v)
return s, s != ""
case json.Number:
s := strings.TrimSpace(v.String())
return s, s != ""
case fmt.Stringer:
s := strings.TrimSpace(v.String())
return s, s != ""
case bool:
return strconv.FormatBool(v), true
case int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
float32, float64:
return fmt.Sprint(v), true
default:
data, err := json.Marshal(v)
if err != nil {
s := strings.TrimSpace(fmt.Sprint(v))
return s, s != ""
}
s := strings.TrimSpace(string(data))
if s == "" || s == "null" {
return "", false
}
return s, true
}
}
func normalizeIdentityUsername(username string) string {
trimmed := strings.TrimSpace(username)
if trimmed == "" {
return ""
}
if strings.HasPrefix(trimmed, "@") {
return trimmed
}
return "@" + trimmed
}
func escapeIdentityAttrValue(value string) string {
replacer := strings.NewReplacer(
"&", "&amp;",
`"`, "&quot;",
"<", "&lt;",
">", "&gt;",
"'", "&apos;",
)
return replacer.Replace(value)
}
@@ -0,0 +1,73 @@
package flow
import (
"strings"
"testing"
"github.com/memohai/memoh/internal/channel"
)
func TestBuildPlatformIdentitiesXML(t *testing.T) {
t.Parallel()
configs := []channel.ChannelConfig{
{
ID: "tg-1",
ChannelType: channel.ChannelTypeTelegram,
ExternalIdentity: "12345",
SelfIdentity: map[string]any{
"user_id": "12345",
"username": "memoh_bot",
},
},
{
ID: "discord-1",
ChannelType: channel.ChannelTypeDiscord,
ExternalIdentity: "98765",
SelfIdentity: map[string]any{
"name": "Memoh & Co",
"username": "@memoh",
},
},
}
got := buildPlatformIdentitiesXML(configs)
want := strings.Join([]string{
`<identity channel="discord" name="Memoh &amp; Co" username="@memoh" external_identity="98765"/>`,
`<identity channel="telegram" user_id="12345" username="@memoh_bot" external_identity="12345"/>`,
}, "\n")
if got != want {
t.Fatalf("unexpected XML:\nwant:\n%s\n\ngot:\n%s", want, got)
}
}
func TestBuildPlatformIdentityLineNormalizesAttrs(t *testing.T) {
t.Parallel()
got := buildPlatformIdentityLine(channel.ChannelConfig{
ChannelType: channel.ChannelTypeTelegram,
SelfIdentity: map[string]any{
"123id": 7,
"display name": `Memoh <Bot>`,
"username": "memoh",
"xml_name": "reserved",
},
})
want := `<identity channel="telegram" attr_123id="7" display_name="Memoh &lt;Bot&gt;" username="@memoh" attr_xml_name="reserved"/>`
if got != want {
t.Fatalf("unexpected identity line:\nwant: %s\ngot: %s", want, got)
}
}
func TestBuildPlatformIdentitiesSectionSkipsEmptyConfigs(t *testing.T) {
t.Parallel()
got := buildPlatformIdentitiesSection([]channel.ChannelConfig{{
ID: "local-1",
ChannelType: channel.ChannelTypeLocal,
}})
if got != "" {
t.Fatalf("expected empty section, got %q", got)
}
}
+31 -6
View File
@@ -22,6 +22,7 @@ import (
"github.com/memohai/memoh/internal/accounts"
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/agent/background"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db/sqlc"
@@ -63,6 +64,10 @@ type gatewayAssetLoader interface {
OpenForGateway(ctx context.Context, botID, contentHash string) (reader io.ReadCloser, mime string, err error)
}
type botChannelConfigReader interface {
ListBotConfigs(ctx context.Context, botID string) ([]channel.ChannelConfig, error)
}
// Resolver orchestrates chat with the internal agent.
type Resolver struct {
agent *agentpkg.Agent
@@ -79,6 +84,7 @@ type Resolver struct {
eventPublisher messageevent.Publisher
skillLoader SkillLoader
assetLoader gatewayAssetLoader
channelStore botChannelConfigReader
pipeline *pipelinepkg.Pipeline
streamHTTPClient *http.Client
bgManager *background.Manager
@@ -161,6 +167,12 @@ func (r *Resolver) SetGatewayAssetLoader(loader gatewayAssetLoader) {
r.assetLoader = loader
}
// SetChannelStore configures the bot channel config store used to load
// platform identity metadata for system prompt generation.
func (r *Resolver) SetChannelStore(store botChannelConfigReader) {
r.channelStore = store
}
// SetCompactionService configures the compaction service for context compaction.
func (r *Resolver) SetCompactionService(s *compaction.Service) {
r.compactionService = s
@@ -642,13 +654,26 @@ func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig)
if cfg.Identity.TimezoneLocation != nil {
now = now.In(cfg.Identity.TimezoneLocation)
}
platformIdentitiesSection := ""
if r.channelStore != nil {
channelConfigs, err := r.channelStore.ListBotConfigs(ctx, cfg.Identity.BotID)
if err != nil {
r.logger.Warn("load bot platform identities failed",
slog.String("bot_id", cfg.Identity.BotID),
slog.Any("error", err),
)
} else {
platformIdentitiesSection = buildPlatformIdentitiesSection(channelConfigs)
}
}
cfg.System = agentpkg.GenerateSystemPrompt(agentpkg.SystemPromptParams{
SessionType: cfg.SessionType,
Skills: cfg.Skills,
Files: files,
Now: now,
Timezone: cfg.Identity.Timezone,
SupportsImageInput: supportsImageInput,
SessionType: cfg.SessionType,
Skills: cfg.Skills,
Files: files,
Now: now,
Timezone: cfg.Identity.Timezone,
SupportsImageInput: supportsImageInput,
PlatformIdentitiesSection: platformIdentitiesSection,
})
if cfg.Query != "" {