feat(agent): include platform self identities in system prompts

Give bots their known per-channel account handles in the system prompt so they can reason about platform-specific self references consistently. Reuse persisted channel self_identity data across chat, discuss, schedule, heartbeat, and subagent prompts.
This commit is contained in:
Acbox
2026-04-16 16:44:29 +08:00
parent e0fc2f514e
commit 9f10033f63
13 changed files with 409 additions and 19 deletions
@@ -0,0 +1,207 @@
package flow
import (
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/memohai/memoh/internal/channel"
)
const platformIdentitiesIntro = "## Platform Identities\n\nThese XML tags describe your own known account identities across connected platforms.\n"
type identityAttr struct {
Name string
Value string
}
func buildPlatformIdentitiesSection(configs []channel.ChannelConfig) string {
xmlBlock := buildPlatformIdentitiesXML(configs)
if xmlBlock == "" {
return ""
}
return platformIdentitiesIntro + "\n" + xmlBlock
}
func buildPlatformIdentitiesXML(configs []channel.ChannelConfig) string {
if len(configs) == 0 {
return ""
}
sorted := make([]channel.ChannelConfig, len(configs))
copy(sorted, configs)
sort.Slice(sorted, func(i, j int) bool {
left := sorted[i]
right := sorted[j]
if cmp := strings.Compare(left.ChannelType.String(), right.ChannelType.String()); cmp != 0 {
return cmp < 0
}
if cmp := strings.Compare(left.ExternalIdentity, right.ExternalIdentity); cmp != 0 {
return cmp < 0
}
return strings.Compare(left.ID, right.ID) < 0
})
lines := make([]string, 0, len(sorted))
for _, cfg := range sorted {
line := buildPlatformIdentityLine(cfg)
if line == "" {
continue
}
lines = append(lines, line)
}
return strings.Join(lines, "\n")
}
func buildPlatformIdentityLine(cfg channel.ChannelConfig) string {
channelName := strings.TrimSpace(cfg.ChannelType.String())
if channelName == "" {
return ""
}
attrs := []identityAttr{{
Name: "channel",
Value: channelName,
}}
keys := make([]string, 0, len(cfg.SelfIdentity))
for key := range cfg.SelfIdentity {
keys = append(keys, key)
}
sort.Strings(keys)
seen := map[string]struct{}{
"channel": {},
}
for _, key := range keys {
name, ok := normalizeIdentityAttrName(key)
if !ok {
continue
}
if _, exists := seen[name]; exists {
continue
}
value, ok := stringifyIdentityAttrValue(cfg.SelfIdentity[key])
if !ok {
continue
}
if name == "username" {
value = normalizeIdentityUsername(value)
if strings.TrimSpace(value) == "" {
continue
}
}
attrs = append(attrs, identityAttr{Name: name, Value: value})
seen[name] = struct{}{}
}
if externalIdentity := strings.TrimSpace(cfg.ExternalIdentity); externalIdentity != "" {
if _, exists := seen["external_identity"]; !exists {
attrs = append(attrs, identityAttr{Name: "external_identity", Value: externalIdentity})
}
}
if len(attrs) == 1 {
return ""
}
var sb strings.Builder
sb.WriteString("<identity")
for _, attr := range attrs {
sb.WriteByte(' ')
sb.WriteString(attr.Name)
sb.WriteString(`="`)
sb.WriteString(escapeIdentityAttrValue(attr.Value))
sb.WriteByte('"')
}
sb.WriteString("/>")
return sb.String()
}
func normalizeIdentityAttrName(name string) (string, bool) {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return "", false
}
var sb strings.Builder
for _, r := range trimmed {
switch {
case unicode.IsLetter(r), unicode.IsDigit(r), r == '_', r == '-', r == '.':
sb.WriteRune(r)
default:
sb.WriteByte('_')
}
}
normalized := strings.Trim(sb.String(), "_-.")
if normalized == "" {
return "", false
}
first, _ := utf8.DecodeRuneInString(normalized)
if !unicode.IsLetter(first) && first != '_' {
normalized = "attr_" + normalized
}
if strings.HasPrefix(strings.ToLower(normalized), "xml") {
normalized = "attr_" + normalized
}
return normalized, true
}
func stringifyIdentityAttrValue(value any) (string, bool) {
switch v := value.(type) {
case nil:
return "", false
case string:
s := strings.TrimSpace(v)
return s, s != ""
case json.Number:
s := strings.TrimSpace(v.String())
return s, s != ""
case fmt.Stringer:
s := strings.TrimSpace(v.String())
return s, s != ""
case bool:
return strconv.FormatBool(v), true
case int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
float32, float64:
return fmt.Sprint(v), true
default:
data, err := json.Marshal(v)
if err != nil {
s := strings.TrimSpace(fmt.Sprint(v))
return s, s != ""
}
s := strings.TrimSpace(string(data))
if s == "" || s == "null" {
return "", false
}
return s, true
}
}
func normalizeIdentityUsername(username string) string {
trimmed := strings.TrimSpace(username)
if trimmed == "" {
return ""
}
if strings.HasPrefix(trimmed, "@") {
return trimmed
}
return "@" + trimmed
}
func escapeIdentityAttrValue(value string) string {
replacer := strings.NewReplacer(
"&", "&amp;",
`"`, "&quot;",
"<", "&lt;",
">", "&gt;",
"'", "&apos;",
)
return replacer.Replace(value)
}
@@ -0,0 +1,73 @@
package flow
import (
"strings"
"testing"
"github.com/memohai/memoh/internal/channel"
)
func TestBuildPlatformIdentitiesXML(t *testing.T) {
t.Parallel()
configs := []channel.ChannelConfig{
{
ID: "tg-1",
ChannelType: channel.ChannelTypeTelegram,
ExternalIdentity: "12345",
SelfIdentity: map[string]any{
"user_id": "12345",
"username": "memoh_bot",
},
},
{
ID: "discord-1",
ChannelType: channel.ChannelTypeDiscord,
ExternalIdentity: "98765",
SelfIdentity: map[string]any{
"name": "Memoh & Co",
"username": "@memoh",
},
},
}
got := buildPlatformIdentitiesXML(configs)
want := strings.Join([]string{
`<identity channel="discord" name="Memoh &amp; Co" username="@memoh" external_identity="98765"/>`,
`<identity channel="telegram" user_id="12345" username="@memoh_bot" external_identity="12345"/>`,
}, "\n")
if got != want {
t.Fatalf("unexpected XML:\nwant:\n%s\n\ngot:\n%s", want, got)
}
}
func TestBuildPlatformIdentityLineNormalizesAttrs(t *testing.T) {
t.Parallel()
got := buildPlatformIdentityLine(channel.ChannelConfig{
ChannelType: channel.ChannelTypeTelegram,
SelfIdentity: map[string]any{
"123id": 7,
"display name": `Memoh <Bot>`,
"username": "memoh",
"xml_name": "reserved",
},
})
want := `<identity channel="telegram" attr_123id="7" display_name="Memoh &lt;Bot&gt;" username="@memoh" attr_xml_name="reserved"/>`
if got != want {
t.Fatalf("unexpected identity line:\nwant: %s\ngot: %s", want, got)
}
}
func TestBuildPlatformIdentitiesSectionSkipsEmptyConfigs(t *testing.T) {
t.Parallel()
got := buildPlatformIdentitiesSection([]channel.ChannelConfig{{
ID: "local-1",
ChannelType: channel.ChannelTypeLocal,
}})
if got != "" {
t.Fatalf("expected empty section, got %q", got)
}
}
+31 -6
View File
@@ -22,6 +22,7 @@ import (
"github.com/memohai/memoh/internal/accounts"
agentpkg "github.com/memohai/memoh/internal/agent"
"github.com/memohai/memoh/internal/agent/background"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/compaction"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db/sqlc"
@@ -63,6 +64,10 @@ type gatewayAssetLoader interface {
OpenForGateway(ctx context.Context, botID, contentHash string) (reader io.ReadCloser, mime string, err error)
}
type botChannelConfigReader interface {
ListBotConfigs(ctx context.Context, botID string) ([]channel.ChannelConfig, error)
}
// Resolver orchestrates chat with the internal agent.
type Resolver struct {
agent *agentpkg.Agent
@@ -79,6 +84,7 @@ type Resolver struct {
eventPublisher messageevent.Publisher
skillLoader SkillLoader
assetLoader gatewayAssetLoader
channelStore botChannelConfigReader
pipeline *pipelinepkg.Pipeline
streamHTTPClient *http.Client
bgManager *background.Manager
@@ -161,6 +167,12 @@ func (r *Resolver) SetGatewayAssetLoader(loader gatewayAssetLoader) {
r.assetLoader = loader
}
// SetChannelStore configures the bot channel config store used to load
// platform identity metadata for system prompt generation.
func (r *Resolver) SetChannelStore(store botChannelConfigReader) {
r.channelStore = store
}
// SetCompactionService configures the compaction service for context compaction.
func (r *Resolver) SetCompactionService(s *compaction.Service) {
r.compactionService = s
@@ -642,13 +654,26 @@ func (r *Resolver) prepareRunConfig(ctx context.Context, cfg agentpkg.RunConfig)
if cfg.Identity.TimezoneLocation != nil {
now = now.In(cfg.Identity.TimezoneLocation)
}
platformIdentitiesSection := ""
if r.channelStore != nil {
channelConfigs, err := r.channelStore.ListBotConfigs(ctx, cfg.Identity.BotID)
if err != nil {
r.logger.Warn("load bot platform identities failed",
slog.String("bot_id", cfg.Identity.BotID),
slog.Any("error", err),
)
} else {
platformIdentitiesSection = buildPlatformIdentitiesSection(channelConfigs)
}
}
cfg.System = agentpkg.GenerateSystemPrompt(agentpkg.SystemPromptParams{
SessionType: cfg.SessionType,
Skills: cfg.Skills,
Files: files,
Now: now,
Timezone: cfg.Identity.Timezone,
SupportsImageInput: supportsImageInput,
SessionType: cfg.SessionType,
Skills: cfg.Skills,
Files: files,
Now: now,
Timezone: cfg.Identity.Timezone,
SupportsImageInput: supportsImageInput,
PlatformIdentitiesSection: platformIdentitiesSection,
})
if cfg.Query != "" {