mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
+43
-1
@@ -1,12 +1,13 @@
|
||||
import { generateText, ModelMessage, stepCountIs, streamText, TextStreamPart, ToolSet } from 'ai'
|
||||
import { createChatGateway } from './gateway'
|
||||
import { BaseModelConfig, Schedule } from './types'
|
||||
import { AgentSkill, BaseModelConfig, Schedule } from './types'
|
||||
import { system, schedule } from './prompts'
|
||||
import { AuthFetcher } from './index'
|
||||
import { getScheduleTools } from './tools/schedule'
|
||||
import { getWebTools } from './tools/web'
|
||||
import { subagentSystem } from './prompts/subagent'
|
||||
import { getSubagentTools } from './tools/subagent'
|
||||
import { getSkillTools } from './tools/skill'
|
||||
|
||||
export enum AgentAction {
|
||||
WebSearch = 'web_search',
|
||||
@@ -25,6 +26,8 @@ export interface AgentParams extends BaseModelConfig {
|
||||
currentPlatform?: string
|
||||
braveApiKey?: string
|
||||
braveBaseUrl?: string
|
||||
skills?: AgentSkill[]
|
||||
useSkills?: string[]
|
||||
allowed?: AgentAction[]
|
||||
}
|
||||
|
||||
@@ -35,6 +38,7 @@ export interface AgentInput {
|
||||
|
||||
export interface AgentResult {
|
||||
messages: ModelMessage[]
|
||||
skills: string[]
|
||||
}
|
||||
|
||||
export const createAgent = (
|
||||
@@ -43,6 +47,10 @@ export const createAgent = (
|
||||
) => {
|
||||
const gateway = createChatGateway(params.clientType)
|
||||
const messages: ModelMessage[] = []
|
||||
const enabledSkills: AgentSkill[] = params.skills ?? []
|
||||
enabledSkills.push(
|
||||
...params.useSkills?.map((name) => params.skills?.find((s) => s.name === name)
|
||||
).filter((s) => s !== undefined) ?? [])
|
||||
|
||||
const allowedActions = params.allowed
|
||||
?? Object.values(AgentAction)
|
||||
@@ -52,6 +60,19 @@ export const createAgent = (
|
||||
const getTools = () => {
|
||||
const tools: ToolSet = {}
|
||||
|
||||
if (allowedActions.includes(AgentAction.Skill)) {
|
||||
const skillTools = getSkillTools({
|
||||
skills: params.skills ?? [],
|
||||
useSkill: (skill) => {
|
||||
if (enabledSkills.some((s) => s.name === skill.name)) {
|
||||
return
|
||||
}
|
||||
enabledSkills.push(skill)
|
||||
}
|
||||
})
|
||||
Object.assign(tools, skillTools)
|
||||
}
|
||||
|
||||
if (allowedActions.includes(AgentAction.Schedule)) {
|
||||
const scheduleTools = getScheduleTools({ fetch: fetcher })
|
||||
Object.assign(tools, scheduleTools)
|
||||
@@ -89,6 +110,8 @@ export const createAgent = (
|
||||
maxContextLoadTime: params.maxContextLoadTime ?? 1550,
|
||||
platforms: params.platforms ?? [],
|
||||
currentPlatform: params.currentPlatform,
|
||||
skills: params.skills ?? [],
|
||||
enabledSkills,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -107,10 +130,16 @@ export const createAgent = (
|
||||
system: generateSystem(),
|
||||
stopWhen: stepCountIs(maxSteps),
|
||||
messages,
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: generateSystem(),
|
||||
}
|
||||
},
|
||||
tools: getTools(),
|
||||
})
|
||||
return {
|
||||
messages: [user, ...response.messages],
|
||||
skills: enabledSkills.map((s) => s.name),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +173,7 @@ export const createAgent = (
|
||||
})
|
||||
return {
|
||||
messages: [user, ...response.messages],
|
||||
skills: enabledSkills.map((s) => s.name),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,6 +192,11 @@ export const createAgent = (
|
||||
system: generateSystem(),
|
||||
stopWhen: stepCountIs(maxSteps),
|
||||
messages,
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: generateSystem(),
|
||||
}
|
||||
},
|
||||
tools: getTools(),
|
||||
})
|
||||
for await (const event of fullStream) {
|
||||
@@ -169,6 +204,7 @@ export const createAgent = (
|
||||
}
|
||||
return {
|
||||
messages: [user, ...(await response).messages],
|
||||
skills: enabledSkills.map((s) => s.name),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,10 +230,16 @@ export const createAgent = (
|
||||
system: generateSystem(),
|
||||
stopWhen: stepCountIs(maxSteps),
|
||||
messages,
|
||||
prepareStep: () => {
|
||||
return {
|
||||
system: generateSystem(),
|
||||
}
|
||||
},
|
||||
tools: getTools(),
|
||||
})
|
||||
return {
|
||||
messages: [user, ...response.messages],
|
||||
skills: enabledSkills.map((s) => s.name),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,12 @@ import { ModelMessage } from 'ai'
|
||||
import { bearerMiddleware } from '../middlewares/bearer'
|
||||
import { loadConfig } from '../config'
|
||||
|
||||
const Skill = z.object({
|
||||
name: z.string().min(1, 'Skill name is required'),
|
||||
description: z.string().min(1, 'Skill description is required'),
|
||||
content: z.string().min(1, 'Skill content is required'),
|
||||
})
|
||||
|
||||
const ChatBody = z.object({
|
||||
apiKey: z.string().min(1, 'API key is required'),
|
||||
baseUrl: z.string().min(1, 'Base URL is required'),
|
||||
@@ -22,6 +28,8 @@ const ChatBody = z.object({
|
||||
maxContextLoadTime: z.number().min(1, 'Max context load time is required'),
|
||||
platforms: z.array(z.string()).optional(),
|
||||
currentPlatform: z.string().optional(),
|
||||
skills: z.array(Skill).optional(),
|
||||
useSkills: z.array(z.string()).optional(),
|
||||
|
||||
messages: z.array(z.any()),
|
||||
query: z.string().min(1, 'Query is required'),
|
||||
@@ -56,6 +64,8 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
currentPlatform: body.currentPlatform,
|
||||
braveApiKey: config.brave?.api_key,
|
||||
braveBaseUrl: config.brave?.base_url,
|
||||
skills: body.skills,
|
||||
useSkills: body.useSkills,
|
||||
}, createAuthFetcher(bearer))
|
||||
try {
|
||||
const result = await ask({
|
||||
@@ -98,6 +108,8 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
currentPlatform: body.currentPlatform,
|
||||
braveApiKey: config.brave?.api_key,
|
||||
braveBaseUrl: config.brave?.base_url,
|
||||
skills: body.skills,
|
||||
useSkills: body.useSkills,
|
||||
}, createAuthFetcher(bearer))
|
||||
try {
|
||||
const streanGenerator = stream({
|
||||
@@ -151,6 +163,8 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
currentPlatform: body.currentPlatform,
|
||||
braveApiKey: config.brave?.api_key,
|
||||
braveBaseUrl: config.brave?.base_url,
|
||||
skills: body.skills,
|
||||
useSkills: body.useSkills,
|
||||
}, createAuthFetcher(bearer))
|
||||
try {
|
||||
return await triggerSchedule({
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { time } from './shared'
|
||||
import { quote } from './utils'
|
||||
import { AgentSkill } from '../types'
|
||||
|
||||
export interface SystemParams {
|
||||
date: Date
|
||||
@@ -8,9 +9,20 @@ export interface SystemParams {
|
||||
maxContextLoadTime: number
|
||||
platforms: string[]
|
||||
currentPlatform?: string
|
||||
skills: AgentSkill[]
|
||||
enabledSkills: AgentSkill[]
|
||||
}
|
||||
|
||||
export const system = ({ date, locale, language, maxContextLoadTime, platforms, currentPlatform }: SystemParams) => {
|
||||
export const skillPrompt = (skill: AgentSkill) => {
|
||||
return `
|
||||
### ${skill.name}
|
||||
> ${skill.description}
|
||||
|
||||
${skill.content}
|
||||
`.trim()
|
||||
}
|
||||
|
||||
export const system = ({ date, locale, language, maxContextLoadTime, platforms, currentPlatform, skills, enabledSkills }: SystemParams) => {
|
||||
return `
|
||||
---
|
||||
${time({ date, locale })}
|
||||
@@ -57,5 +69,14 @@ When a task is large, you can create a Subagent to help you complete some tasks
|
||||
+ The ${quote('name')} is the name of the subagent to ask.
|
||||
+ The ${quote('query')} is the prompt to ask the subagent to complete the task.
|
||||
Before asking a subagent, you should first create a subagent if it does not exist.
|
||||
|
||||
**Skills**
|
||||
|
||||
There are ${skills.length} skills available, you can use ${quote('use_skill')} to use a skill.
|
||||
${skills.map(skill => `- ${skill.name}: ${skill.description}`).join('\n')}
|
||||
|
||||
**Enabled Skills**
|
||||
|
||||
${enabledSkills.map(skill => skillPrompt(skill)).join('\n\n---\n\n')}
|
||||
`.trim()
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
import { AgentSkill } from '../types'
|
||||
import { tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
interface SkillToolParams {
|
||||
skills: AgentSkill[]
|
||||
useSkill: (skill: AgentSkill, reason: string) => void
|
||||
}
|
||||
|
||||
export const getSkillTools = ({ skills, useSkill }: SkillToolParams) => {
|
||||
const useSkillTool = tool({
|
||||
description: 'Use a skill if you think it is relevant to the current task',
|
||||
inputSchema: z.object({
|
||||
skillName: z.string().describe('The name of the skill to use'),
|
||||
reason: z.string().describe('The reason why you think this skill is relevant to the current task'),
|
||||
}),
|
||||
execute: async ({ skillName, reason }) => {
|
||||
const skill = skills.find((s) => s.name === skillName)
|
||||
if (!skill) {
|
||||
return { error: 'Skill not found' }
|
||||
}
|
||||
await useSkill(skill, reason)
|
||||
return {
|
||||
success: true,
|
||||
skillName,
|
||||
reason,
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
'use_skill': useSkillTool,
|
||||
}
|
||||
}
|
||||
@@ -18,4 +18,10 @@ export interface Schedule {
|
||||
pattern: string
|
||||
maxCalls?: number
|
||||
command: string
|
||||
}
|
||||
|
||||
export interface AgentSkill {
|
||||
name: string
|
||||
description: string
|
||||
content: string
|
||||
}
|
||||
@@ -134,6 +134,7 @@ CREATE INDEX IF NOT EXISTS idx_lifecycle_events_event_type ON lifecycle_events(e
|
||||
CREATE TABLE IF NOT EXISTS history (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
messages JSONB NOT NULL,
|
||||
skills TEXT[] NOT NULL DEFAULT '{}'::text[],
|
||||
timestamp TIMESTAMPTZ NOT NULL,
|
||||
"user" UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
-- name: CreateHistory :one
|
||||
INSERT INTO history (messages, timestamp, "user")
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, messages, timestamp, "user";
|
||||
INSERT INTO history (messages, skills, timestamp, "user")
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, messages, skills, timestamp, "user";
|
||||
|
||||
-- name: ListHistoryByUserSince :many
|
||||
SELECT id, messages, timestamp, "user"
|
||||
SELECT id, messages, skills, timestamp, "user"
|
||||
FROM history
|
||||
WHERE "user" = $1 AND timestamp >= $2
|
||||
ORDER BY timestamp ASC;
|
||||
|
||||
-- name: GetHistoryByID :one
|
||||
SELECT id, messages, timestamp, "user"
|
||||
SELECT id, messages, skills, timestamp, "user"
|
||||
FROM history
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: ListHistoryByUser :many
|
||||
SELECT id, messages, timestamp, "user"
|
||||
SELECT id, messages, skills, timestamp, "user"
|
||||
FROM history
|
||||
WHERE "user" = $1
|
||||
ORDER BY timestamp DESC
|
||||
|
||||
@@ -2594,6 +2594,20 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"definitions": {
|
||||
"chat.AgentSkill": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"chat.ChatRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -2632,6 +2646,18 @@ const docTemplate = `{
|
||||
},
|
||||
"query": {
|
||||
"type": "string"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/chat.AgentSkill"
|
||||
}
|
||||
},
|
||||
"use_skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -2649,6 +2675,12 @@ const docTemplate = `{
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -2981,6 +3013,12 @@ const docTemplate = `{
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
}
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -3008,6 +3046,12 @@ const docTemplate = `{
|
||||
"additionalProperties": true
|
||||
}
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
@@ -2585,6 +2585,20 @@
|
||||
}
|
||||
},
|
||||
"definitions": {
|
||||
"chat.AgentSkill": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"chat.ChatRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -2623,6 +2637,18 @@
|
||||
},
|
||||
"query": {
|
||||
"type": "string"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/chat.AgentSkill"
|
||||
}
|
||||
},
|
||||
"use_skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -2640,6 +2666,12 @@
|
||||
},
|
||||
"provider": {
|
||||
"type": "string"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -2972,6 +3004,12 @@
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
}
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -2999,6 +3037,12 @@
|
||||
"additionalProperties": true
|
||||
}
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
@@ -1,4 +1,13 @@
|
||||
definitions:
|
||||
chat.AgentSkill:
|
||||
properties:
|
||||
content:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
chat.ChatRequest:
|
||||
properties:
|
||||
current_platform:
|
||||
@@ -25,6 +34,14 @@ definitions:
|
||||
type: string
|
||||
query:
|
||||
type: string
|
||||
skills:
|
||||
items:
|
||||
$ref: '#/definitions/chat.AgentSkill'
|
||||
type: array
|
||||
use_skills:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
type: object
|
||||
chat.ChatResponse:
|
||||
properties:
|
||||
@@ -36,6 +53,10 @@ definitions:
|
||||
type: string
|
||||
provider:
|
||||
type: string
|
||||
skills:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
type: object
|
||||
chat.GatewayMessage:
|
||||
additionalProperties: true
|
||||
@@ -251,6 +272,10 @@ definitions:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
type: array
|
||||
skills:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
type: object
|
||||
history.ListResponse:
|
||||
properties:
|
||||
@@ -268,6 +293,10 @@ definitions:
|
||||
additionalProperties: true
|
||||
type: object
|
||||
type: array
|
||||
skills:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
timestamp:
|
||||
type: string
|
||||
user_id:
|
||||
|
||||
@@ -88,16 +88,22 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
|
||||
}
|
||||
|
||||
var messages []GatewayMessage
|
||||
var historySkills []string
|
||||
if !skipHistory {
|
||||
messages, err = r.loadHistoryMessages(ctx, req.UserID, maxContextLoadTime)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
historySkills, err = r.loadHistorySkills(ctx, req.UserID, maxContextLoadTime)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
}
|
||||
if len(req.Messages) > 0 {
|
||||
messages = append(messages, req.Messages...)
|
||||
}
|
||||
messages = sanitizeGatewayMessages(messages)
|
||||
useSkills := normalizeSkills(append(historySkills, req.UseSkills...))
|
||||
|
||||
payload := agentGatewayRequest{
|
||||
APIKey: provider.ApiKey,
|
||||
@@ -112,6 +118,8 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
|
||||
CurrentPlatform: req.CurrentPlatform,
|
||||
Messages: messages,
|
||||
Query: req.Query,
|
||||
Skills: req.Skills,
|
||||
UseSkills: useSkills,
|
||||
}
|
||||
payload.Language = language
|
||||
|
||||
@@ -120,7 +128,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
|
||||
if err := r.storeHistory(ctx, req.UserID, req.Query, resp.Messages); err != nil {
|
||||
if err := r.storeHistory(ctx, req.UserID, req.Query, resp.Messages, resp.Skills); err != nil {
|
||||
return ChatResponse{}, err
|
||||
}
|
||||
if err := r.storeMemory(ctx, req.UserID, req.Query, resp.Messages); err != nil {
|
||||
@@ -129,6 +137,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err
|
||||
|
||||
return ChatResponse{
|
||||
Messages: resp.Messages,
|
||||
Skills: resp.Skills,
|
||||
Model: chatModel.ModelID,
|
||||
Provider: provider.ClientType,
|
||||
}, nil
|
||||
@@ -166,6 +175,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
historySkills, err := r.loadHistorySkills(ctx, userID, maxContextLoadTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
useSkills := normalizeSkills(historySkills)
|
||||
|
||||
payload := agentGatewayScheduleRequest{
|
||||
APIKey: provider.ApiKey,
|
||||
@@ -181,13 +195,14 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule
|
||||
Messages: messages,
|
||||
Query: schedule.Command,
|
||||
Schedule: schedule,
|
||||
UseSkills: useSkills,
|
||||
}
|
||||
|
||||
resp, err := r.postSchedule(ctx, payload, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.storeHistory(ctx, userID, schedule.Command, resp.Messages); err != nil {
|
||||
if err := r.storeHistory(ctx, userID, schedule.Command, resp.Messages, resp.Skills); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.storeMemory(ctx, userID, schedule.Command, resp.Messages); err != nil {
|
||||
@@ -238,17 +253,24 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre
|
||||
}
|
||||
|
||||
var messages []GatewayMessage
|
||||
var historySkills []string
|
||||
if !skipHistory {
|
||||
messages, err = r.loadHistoryMessages(ctx, req.UserID, maxContextLoadTime)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
historySkills, err = r.loadHistorySkills(ctx, req.UserID, maxContextLoadTime)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(req.Messages) > 0 {
|
||||
messages = append(messages, req.Messages...)
|
||||
}
|
||||
messages = sanitizeGatewayMessages(messages)
|
||||
useSkills := normalizeSkills(append(historySkills, req.UseSkills...))
|
||||
|
||||
payload := agentGatewayRequest{
|
||||
APIKey: provider.ApiKey,
|
||||
@@ -263,6 +285,8 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre
|
||||
CurrentPlatform: req.CurrentPlatform,
|
||||
Messages: messages,
|
||||
Query: req.Query,
|
||||
Skills: req.Skills,
|
||||
UseSkills: useSkills,
|
||||
}
|
||||
payload.Language = language
|
||||
|
||||
@@ -288,6 +312,8 @@ type agentGatewayRequest struct {
|
||||
CurrentPlatform string `json:"currentPlatform,omitempty"`
|
||||
Messages []GatewayMessage `json:"messages"`
|
||||
Query string `json:"query"`
|
||||
Skills []AgentSkill `json:"skills,omitempty"`
|
||||
UseSkills []string `json:"useSkills,omitempty"`
|
||||
}
|
||||
|
||||
type agentGatewayScheduleRequest struct {
|
||||
@@ -304,10 +330,13 @@ type agentGatewayScheduleRequest struct {
|
||||
Messages []GatewayMessage `json:"messages"`
|
||||
Query string `json:"query"`
|
||||
Schedule SchedulePayload `json:"schedule"`
|
||||
Skills []AgentSkill `json:"skills,omitempty"`
|
||||
UseSkills []string `json:"useSkills,omitempty"`
|
||||
}
|
||||
|
||||
type agentGatewayResponse struct {
|
||||
Messages []GatewayMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
}
|
||||
|
||||
func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest, token string) (agentGatewayResponse, error) {
|
||||
@@ -475,7 +504,36 @@ func (r *Resolver) loadHistoryMessages(ctx context.Context, userID string, maxCo
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) storeHistory(ctx context.Context, userID, query string, responseMessages []GatewayMessage) error {
|
||||
func (r *Resolver) loadHistorySkills(ctx context.Context, userID string, maxContextLoadTime int) ([]string, error) {
|
||||
if r.queries == nil {
|
||||
return nil, fmt.Errorf("history queries not configured")
|
||||
}
|
||||
pgUserID, err := parseUUID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
from := time.Now().UTC().Add(-time.Duration(normalizeMaxContextLoad(maxContextLoadTime)) * time.Minute)
|
||||
rows, err := r.queries.ListHistoryByUserSince(ctx, sqlc.ListHistoryByUserSinceParams{
|
||||
User: pgUserID,
|
||||
Timestamp: pgtype.Timestamptz{
|
||||
Time: from,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
combined := make([]string, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
if len(row.Skills) == 0 {
|
||||
continue
|
||||
}
|
||||
combined = append(combined, row.Skills...)
|
||||
}
|
||||
return normalizeSkills(combined), nil
|
||||
}
|
||||
|
||||
func (r *Resolver) storeHistory(ctx context.Context, userID, query string, responseMessages []GatewayMessage, skills []string) error {
|
||||
if r.queries == nil {
|
||||
return fmt.Errorf("history queries not configured")
|
||||
}
|
||||
@@ -496,8 +554,10 @@ func (r *Resolver) storeHistory(ctx context.Context, userID, query string, respo
|
||||
if err := r.ensureUserExists(ctx, pgUserID); err != nil {
|
||||
return err
|
||||
}
|
||||
normalizedSkills := normalizeSkills(skills)
|
||||
_, err = r.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{
|
||||
Messages: payload,
|
||||
Skills: normalizedSkills,
|
||||
Timestamp: pgtype.Timestamptz{
|
||||
Time: time.Now().UTC(),
|
||||
Valid: true,
|
||||
@@ -582,7 +642,7 @@ func (r *Resolver) tryStoreFromStreamPayload(ctx context.Context, userID, query,
|
||||
// Case 1: event: done + data: {messages: [...]}
|
||||
if eventType == "done" {
|
||||
if parsed, ok := parseGatewayResponse([]byte(data)); ok {
|
||||
return r.storeRound(ctx, userID, query, parsed.Messages)
|
||||
return r.storeRound(ctx, userID, query, parsed.Messages, parsed.Skills)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -594,14 +654,14 @@ func (r *Resolver) tryStoreFromStreamPayload(ctx context.Context, userID, query,
|
||||
if err := json.Unmarshal([]byte(data), &envelope); err == nil {
|
||||
if envelope.Type == "done" && len(envelope.Data) > 0 {
|
||||
if parsed, ok := parseGatewayResponse(envelope.Data); ok {
|
||||
return r.storeRound(ctx, userID, query, parsed.Messages)
|
||||
return r.storeRound(ctx, userID, query, parsed.Messages, parsed.Skills)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 3: data: {messages:[...]} without event
|
||||
if parsed, ok := parseGatewayResponse([]byte(data)); ok {
|
||||
return r.storeRound(ctx, userID, query, parsed.Messages)
|
||||
return r.storeRound(ctx, userID, query, parsed.Messages, parsed.Skills)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -617,8 +677,8 @@ func parseGatewayResponse(payload []byte) (agentGatewayResponse, bool) {
|
||||
return parsed, true
|
||||
}
|
||||
|
||||
func (r *Resolver) storeRound(ctx context.Context, userID, query string, messages []GatewayMessage) (bool, error) {
|
||||
if err := r.storeHistory(ctx, userID, query, messages); err != nil {
|
||||
func (r *Resolver) storeRound(ctx context.Context, userID, query string, messages []GatewayMessage, skills []string) (bool, error) {
|
||||
if err := r.storeHistory(ctx, userID, query, messages, skills); err != nil {
|
||||
return true, err
|
||||
}
|
||||
if err := r.storeMemory(ctx, userID, query, messages); err != nil {
|
||||
@@ -627,6 +687,23 @@ func (r *Resolver) storeRound(ctx context.Context, userID, query string, message
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func normalizeSkills(skills []string) []string {
|
||||
seen := map[string]struct{}{}
|
||||
normalized := make([]string, 0, len(skills))
|
||||
for _, skill := range skills {
|
||||
trimmed := strings.TrimSpace(skill)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func gatewayMessageToMemory(msg GatewayMessage) (string, string) {
|
||||
role := "assistant"
|
||||
if raw, ok := msg["role"].(string); ok && strings.TrimSpace(raw) != "" {
|
||||
|
||||
@@ -9,6 +9,12 @@ type Message struct {
|
||||
|
||||
type GatewayMessage map[string]interface{}
|
||||
|
||||
type AgentSkill struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
UserID string `json:"-"`
|
||||
Token string `json:"-"`
|
||||
@@ -22,10 +28,13 @@ type ChatRequest struct {
|
||||
Platforms []string `json:"platforms,omitempty"`
|
||||
CurrentPlatform string `json:"current_platform,omitempty"`
|
||||
Messages []GatewayMessage `json:"messages,omitempty"`
|
||||
Skills []AgentSkill `json:"skills,omitempty"`
|
||||
UseSkills []string `json:"use_skills,omitempty"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Messages []GatewayMessage `json:"messages"`
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
}
|
||||
|
||||
@@ -12,23 +12,30 @@ import (
|
||||
)
|
||||
|
||||
const createHistory = `-- name: CreateHistory :one
|
||||
INSERT INTO history (messages, timestamp, "user")
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, messages, timestamp, "user"
|
||||
INSERT INTO history (messages, skills, timestamp, "user")
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING id, messages, skills, timestamp, "user"
|
||||
`
|
||||
|
||||
type CreateHistoryParams struct {
|
||||
Messages []byte `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Timestamp pgtype.Timestamptz `json:"timestamp"`
|
||||
User pgtype.UUID `json:"user"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (History, error) {
|
||||
row := q.db.QueryRow(ctx, createHistory, arg.Messages, arg.Timestamp, arg.User)
|
||||
row := q.db.QueryRow(ctx, createHistory,
|
||||
arg.Messages,
|
||||
arg.Skills,
|
||||
arg.Timestamp,
|
||||
arg.User,
|
||||
)
|
||||
var i History
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Messages,
|
||||
&i.Skills,
|
||||
&i.Timestamp,
|
||||
&i.User,
|
||||
)
|
||||
@@ -56,7 +63,7 @@ func (q *Queries) DeleteHistoryByUser(ctx context.Context, user pgtype.UUID) err
|
||||
}
|
||||
|
||||
const getHistoryByID = `-- name: GetHistoryByID :one
|
||||
SELECT id, messages, timestamp, "user"
|
||||
SELECT id, messages, skills, timestamp, "user"
|
||||
FROM history
|
||||
WHERE id = $1
|
||||
`
|
||||
@@ -67,6 +74,7 @@ func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History,
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Messages,
|
||||
&i.Skills,
|
||||
&i.Timestamp,
|
||||
&i.User,
|
||||
)
|
||||
@@ -74,7 +82,7 @@ func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History,
|
||||
}
|
||||
|
||||
const listHistoryByUser = `-- name: ListHistoryByUser :many
|
||||
SELECT id, messages, timestamp, "user"
|
||||
SELECT id, messages, skills, timestamp, "user"
|
||||
FROM history
|
||||
WHERE "user" = $1
|
||||
ORDER BY timestamp DESC
|
||||
@@ -98,6 +106,7 @@ func (q *Queries) ListHistoryByUser(ctx context.Context, arg ListHistoryByUserPa
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Messages,
|
||||
&i.Skills,
|
||||
&i.Timestamp,
|
||||
&i.User,
|
||||
); err != nil {
|
||||
@@ -112,7 +121,7 @@ func (q *Queries) ListHistoryByUser(ctx context.Context, arg ListHistoryByUserPa
|
||||
}
|
||||
|
||||
const listHistoryByUserSince = `-- name: ListHistoryByUserSince :many
|
||||
SELECT id, messages, timestamp, "user"
|
||||
SELECT id, messages, skills, timestamp, "user"
|
||||
FROM history
|
||||
WHERE "user" = $1 AND timestamp >= $2
|
||||
ORDER BY timestamp ASC
|
||||
@@ -135,6 +144,7 @@ func (q *Queries) ListHistoryByUserSince(ctx context.Context, arg ListHistoryByU
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Messages,
|
||||
&i.Skills,
|
||||
&i.Timestamp,
|
||||
&i.User,
|
||||
); err != nil {
|
||||
|
||||
@@ -36,6 +36,7 @@ type ContainerVersion struct {
|
||||
type History struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Messages []byte `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Timestamp pgtype.Timestamptz `json:"timestamp"`
|
||||
User pgtype.UUID `json:"user"`
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ func (s *Service) Create(ctx context.Context, userID string, req CreateRequest)
|
||||
}
|
||||
row, err := s.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{
|
||||
Messages: payload,
|
||||
Skills: normalizeSkills(req.Skills),
|
||||
Timestamp: pgtype.Timestamptz{
|
||||
Time: time.Now().UTC(),
|
||||
Valid: true,
|
||||
@@ -122,6 +123,7 @@ func toRecord(row sqlc.History) (Record, error) {
|
||||
}
|
||||
record := Record{
|
||||
Messages: messages,
|
||||
Skills: normalizeSkills(row.Skills),
|
||||
}
|
||||
if row.Timestamp.Valid {
|
||||
record.Timestamp = row.Timestamp.Time
|
||||
@@ -141,6 +143,23 @@ func toRecord(row sqlc.History) (Record, error) {
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func normalizeSkills(skills []string) []string {
|
||||
seen := map[string]struct{}{}
|
||||
normalized := make([]string, 0, len(skills))
|
||||
for _, skill := range skills {
|
||||
trimmed := strings.TrimSpace(skill)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[trimmed]; ok {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func parseUUID(id string) (pgtype.UUID, error) {
|
||||
parsed, err := uuid.Parse(strings.TrimSpace(id))
|
||||
if err != nil {
|
||||
|
||||
@@ -5,12 +5,14 @@ import "time"
|
||||
type Record struct {
|
||||
ID string `json:"id"`
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
}
|
||||
|
||||
type ListResponse struct {
|
||||
|
||||
Reference in New Issue
Block a user