mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: enable memory search tool
This commit is contained in:
@@ -8,6 +8,7 @@ import { getWebTools } from './tools/web'
|
|||||||
import { subagentSystem } from './prompts/subagent'
|
import { subagentSystem } from './prompts/subagent'
|
||||||
import { getSubagentTools } from './tools/subagent'
|
import { getSubagentTools } from './tools/subagent'
|
||||||
import { getSkillTools } from './tools/skill'
|
import { getSkillTools } from './tools/skill'
|
||||||
|
import { getMemoryTools } from './tools/memory'
|
||||||
|
|
||||||
export enum AgentAction {
|
export enum AgentAction {
|
||||||
WebSearch = 'web_search',
|
WebSearch = 'web_search',
|
||||||
@@ -15,6 +16,7 @@ export enum AgentAction {
|
|||||||
Subagent = 'subagent',
|
Subagent = 'subagent',
|
||||||
Schedule = 'schedule',
|
Schedule = 'schedule',
|
||||||
Skill = 'skill',
|
Skill = 'skill',
|
||||||
|
Memory = 'memory',
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AgentParams extends BaseModelConfig {
|
export interface AgentParams extends BaseModelConfig {
|
||||||
@@ -98,6 +100,11 @@ export const createAgent = (
|
|||||||
})
|
})
|
||||||
Object.assign(tools, subagentTools)
|
Object.assign(tools, subagentTools)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (allowedActions.includes(AgentAction.Memory)) {
|
||||||
|
const memoryTools = getMemoryTools({ fetch: fetcher })
|
||||||
|
Object.assign(tools, memoryTools)
|
||||||
|
}
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,16 @@ export type MemoryToolParams = {
|
|||||||
fetch: AuthFetcher
|
fetch: AuthFetcher
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MemorySearchItem = {
|
||||||
|
id?: string
|
||||||
|
memory?: string
|
||||||
|
score?: number
|
||||||
|
createdAt?: string
|
||||||
|
metadata?: {
|
||||||
|
source?: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export const getMemoryTools = ({ fetch }: MemoryToolParams) => {
|
export const getMemoryTools = ({ fetch }: MemoryToolParams) => {
|
||||||
const searchMemory = tool({
|
const searchMemory = tool({
|
||||||
description: 'Search for memories',
|
description: 'Search for memories',
|
||||||
@@ -13,8 +23,29 @@ export const getMemoryTools = ({ fetch }: MemoryToolParams) => {
|
|||||||
query: z.string().describe('The query to search for memories'),
|
query: z.string().describe('The query to search for memories'),
|
||||||
}),
|
}),
|
||||||
execute: async ({ query }) => {
|
execute: async ({ query }) => {
|
||||||
const response = await fetch(`/memory/search?query=${query}`)
|
const response = await fetch('/memory/search', {
|
||||||
return response.json()
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
query,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
const data = await response.json()
|
||||||
|
const results = Array.isArray(data?.results)
|
||||||
|
? (data.results as MemorySearchItem[])
|
||||||
|
: []
|
||||||
|
const simplified = results.map((item) => ({
|
||||||
|
id: item?.id,
|
||||||
|
memory: item?.memory,
|
||||||
|
score: item?.score,
|
||||||
|
}))
|
||||||
|
return {
|
||||||
|
query,
|
||||||
|
total: simplified.length,
|
||||||
|
results: simplified,
|
||||||
|
}
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -625,13 +625,8 @@ func (r *Resolver) storeMemory(ctx context.Context, userID, query string, respon
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userMessage := GatewayMessage{
|
memoryMessages := make([]memory.Message, 0, len(responseMessages))
|
||||||
"role": "user",
|
for _, msg := range responseMessages {
|
||||||
"content": query,
|
|
||||||
}
|
|
||||||
messages := append([]GatewayMessage{userMessage}, responseMessages...)
|
|
||||||
memoryMessages := make([]memory.Message, 0, len(messages))
|
|
||||||
for _, msg := range messages {
|
|
||||||
role, content := gatewayMessageToMemory(msg)
|
role, content := gatewayMessageToMemory(msg)
|
||||||
if strings.TrimSpace(content) == "" {
|
if strings.TrimSpace(content) == "" {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ sql:
|
|||||||
- engine: "postgresql"
|
- engine: "postgresql"
|
||||||
schema:
|
schema:
|
||||||
- "db/migrations/0001_init.up.sql"
|
- "db/migrations/0001_init.up.sql"
|
||||||
- "db/migrations/0002_channel.up.sql"
|
# - "db/migrations/0002_channel.up.sql"
|
||||||
queries: "db/queries"
|
queries: "db/queries"
|
||||||
gen:
|
gen:
|
||||||
go:
|
go:
|
||||||
|
|||||||
Reference in New Issue
Block a user