diff --git a/agent/src/agent.ts b/agent/src/agent.ts index aeb053ee..4e3586b1 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -8,6 +8,7 @@ import { getWebTools } from './tools/web' import { subagentSystem } from './prompts/subagent' import { getSubagentTools } from './tools/subagent' import { getSkillTools } from './tools/skill' +import { getMemoryTools } from './tools/memory' export enum AgentAction { WebSearch = 'web_search', @@ -15,6 +16,7 @@ export enum AgentAction { Subagent = 'subagent', Schedule = 'schedule', Skill = 'skill', + Memory = 'memory', } export interface AgentParams extends BaseModelConfig { @@ -98,6 +100,11 @@ export const createAgent = ( }) Object.assign(tools, subagentTools) } + + if (allowedActions.includes(AgentAction.Memory)) { + const memoryTools = getMemoryTools({ fetch: fetcher }) + Object.assign(tools, memoryTools) + } return tools } diff --git a/agent/src/tools/memory.ts b/agent/src/tools/memory.ts index dd8d96c8..3936fae9 100644 --- a/agent/src/tools/memory.ts +++ b/agent/src/tools/memory.ts @@ -6,6 +6,16 @@ export type MemoryToolParams = { fetch: AuthFetcher } +type MemorySearchItem = { + id?: string + memory?: string + score?: number + createdAt?: string + metadata?: { + source?: string + } +} + export const getMemoryTools = ({ fetch }: MemoryToolParams) => { const searchMemory = tool({ description: 'Search for memories', @@ -13,8 +23,29 @@ export const getMemoryTools = ({ fetch }: MemoryToolParams) => { query: z.string().describe('The query to search for memories'), }), execute: async ({ query }) => { - const response = await fetch(`/memory/search?query=${query}`) - return response.json() + const response = await fetch('/memory/search', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + query, + }), + }) + const data = await response.json() + const results = Array.isArray(data?.results) + ? (data.results as MemorySearchItem[]) + : [] + const simplified = results.map((item) => ({ + id: item?.id, + memory: item?.memory, + score: item?.score, + })) + return { + query, + total: simplified.length, + results: simplified, + } }, }) diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 79ecf9b7..8d65bce4 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -625,13 +625,8 @@ func (r *Resolver) storeMemory(ctx context.Context, userID, query string, respon return nil } - userMessage := GatewayMessage{ - "role": "user", - "content": query, - } - messages := append([]GatewayMessage{userMessage}, responseMessages...) - memoryMessages := make([]memory.Message, 0, len(messages)) - for _, msg := range messages { + memoryMessages := make([]memory.Message, 0, len(responseMessages)) + for _, msg := range responseMessages { role, content := gatewayMessageToMemory(msg) if strings.TrimSpace(content) == "" { continue diff --git a/sqlc.yaml b/sqlc.yaml index b3c2a82e..7337cc80 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -3,7 +3,7 @@ sql: - engine: "postgresql" schema: - "db/migrations/0001_init.up.sql" - - "db/migrations/0002_channel.up.sql" + # - "db/migrations/0002_channel.up.sql" queries: "db/queries" gen: go: