diff --git a/agent/src/agent.ts b/agent/src/agent.ts index 1751e729..aeb053ee 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -1,12 +1,13 @@ import { generateText, ModelMessage, stepCountIs, streamText, TextStreamPart, ToolSet } from 'ai' import { createChatGateway } from './gateway' -import { BaseModelConfig, Schedule } from './types' +import { AgentSkill, BaseModelConfig, Schedule } from './types' import { system, schedule } from './prompts' import { AuthFetcher } from './index' import { getScheduleTools } from './tools/schedule' import { getWebTools } from './tools/web' import { subagentSystem } from './prompts/subagent' import { getSubagentTools } from './tools/subagent' +import { getSkillTools } from './tools/skill' export enum AgentAction { WebSearch = 'web_search', @@ -25,6 +26,8 @@ export interface AgentParams extends BaseModelConfig { currentPlatform?: string braveApiKey?: string braveBaseUrl?: string + skills?: AgentSkill[] + useSkills?: string[] allowed?: AgentAction[] } @@ -35,6 +38,7 @@ export interface AgentInput { export interface AgentResult { messages: ModelMessage[] + skills: string[] } export const createAgent = ( @@ -43,6 +47,10 @@ export const createAgent = ( ) => { const gateway = createChatGateway(params.clientType) const messages: ModelMessage[] = [] + const enabledSkills: AgentSkill[] = params.skills ?? [] + enabledSkills.push( + ...params.useSkills?.map((name) => params.skills?.find((s) => s.name === name) + ).filter((s) => s !== undefined) ?? []) const allowedActions = params.allowed ?? Object.values(AgentAction) @@ -52,6 +60,19 @@ export const createAgent = ( const getTools = () => { const tools: ToolSet = {} + if (allowedActions.includes(AgentAction.Skill)) { + const skillTools = getSkillTools({ + skills: params.skills ?? [], + useSkill: (skill) => { + if (enabledSkills.some((s) => s.name === skill.name)) { + return + } + enabledSkills.push(skill) + } + }) + Object.assign(tools, skillTools) + } + if (allowedActions.includes(AgentAction.Schedule)) { const scheduleTools = getScheduleTools({ fetch: fetcher }) Object.assign(tools, scheduleTools) @@ -89,6 +110,8 @@ export const createAgent = ( maxContextLoadTime: params.maxContextLoadTime ?? 1550, platforms: params.platforms ?? [], currentPlatform: params.currentPlatform, + skills: params.skills ?? [], + enabledSkills, }) } @@ -107,10 +130,16 @@ export const createAgent = ( system: generateSystem(), stopWhen: stepCountIs(maxSteps), messages, + prepareStep: () => { + return { + system: generateSystem(), + } + }, tools: getTools(), }) return { messages: [user, ...response.messages], + skills: enabledSkills.map((s) => s.name), } } @@ -144,6 +173,7 @@ export const createAgent = ( }) return { messages: [user, ...response.messages], + skills: enabledSkills.map((s) => s.name), } } @@ -162,6 +192,11 @@ export const createAgent = ( system: generateSystem(), stopWhen: stepCountIs(maxSteps), messages, + prepareStep: () => { + return { + system: generateSystem(), + } + }, tools: getTools(), }) for await (const event of fullStream) { @@ -169,6 +204,7 @@ export const createAgent = ( } return { messages: [user, ...(await response).messages], + skills: enabledSkills.map((s) => s.name), } } @@ -194,10 +230,16 @@ export const createAgent = ( system: generateSystem(), stopWhen: stepCountIs(maxSteps), messages, + prepareStep: () => { + return { + system: generateSystem(), + } + }, tools: getTools(), }) return { messages: [user, ...response.messages], + skills: enabledSkills.map((s) => s.name), } } diff --git a/agent/src/modules/chat.ts b/agent/src/modules/chat.ts index 5aaadef4..67e53d41 100644 --- a/agent/src/modules/chat.ts +++ b/agent/src/modules/chat.ts @@ -7,6 +7,12 @@ import { ModelMessage } from 'ai' import { bearerMiddleware } from '../middlewares/bearer' import { loadConfig } from '../config' +const Skill = z.object({ + name: z.string().min(1, 'Skill name is required'), + description: z.string().min(1, 'Skill description is required'), + content: z.string().min(1, 'Skill content is required'), +}) + const ChatBody = z.object({ apiKey: z.string().min(1, 'API key is required'), baseUrl: z.string().min(1, 'Base URL is required'), @@ -22,6 +28,8 @@ const ChatBody = z.object({ maxContextLoadTime: z.number().min(1, 'Max context load time is required'), platforms: z.array(z.string()).optional(), currentPlatform: z.string().optional(), + skills: z.array(Skill).optional(), + useSkills: z.array(z.string()).optional(), messages: z.array(z.any()), query: z.string().min(1, 'Query is required'), @@ -56,6 +64,8 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentPlatform: body.currentPlatform, braveApiKey: config.brave?.api_key, braveBaseUrl: config.brave?.base_url, + skills: body.skills, + useSkills: body.useSkills, }, createAuthFetcher(bearer)) try { const result = await ask({ @@ -98,6 +108,8 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentPlatform: body.currentPlatform, braveApiKey: config.brave?.api_key, braveBaseUrl: config.brave?.base_url, + skills: body.skills, + useSkills: body.useSkills, }, createAuthFetcher(bearer)) try { const streanGenerator = stream({ @@ -151,6 +163,8 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentPlatform: body.currentPlatform, braveApiKey: config.brave?.api_key, braveBaseUrl: config.brave?.base_url, + skills: body.skills, + useSkills: body.useSkills, }, createAuthFetcher(bearer)) try { return await triggerSchedule({ diff --git a/agent/src/prompts/system.ts b/agent/src/prompts/system.ts index 4cef169c..16728d18 100644 --- a/agent/src/prompts/system.ts +++ b/agent/src/prompts/system.ts @@ -1,5 +1,6 @@ import { time } from './shared' import { quote } from './utils' +import { AgentSkill } from '../types' export interface SystemParams { date: Date @@ -8,9 +9,20 @@ export interface SystemParams { maxContextLoadTime: number platforms: string[] currentPlatform?: string + skills: AgentSkill[] + enabledSkills: AgentSkill[] } -export const system = ({ date, locale, language, maxContextLoadTime, platforms, currentPlatform }: SystemParams) => { +export const skillPrompt = (skill: AgentSkill) => { + return ` +### ${skill.name} +> ${skill.description} + +${skill.content} + `.trim() +} + +export const system = ({ date, locale, language, maxContextLoadTime, platforms, currentPlatform, skills, enabledSkills }: SystemParams) => { return ` --- ${time({ date, locale })} @@ -57,5 +69,14 @@ When a task is large, you can create a Subagent to help you complete some tasks + The ${quote('name')} is the name of the subagent to ask. + The ${quote('query')} is the prompt to ask the subagent to complete the task. Before asking a subagent, you should first create a subagent if it does not exist. + +**Skills** + +There are ${skills.length} skills available, you can use ${quote('use_skill')} to use a skill. +${skills.map(skill => `- ${skill.name}: ${skill.description}`).join('\n')} + +**Enabled Skills** + +${enabledSkills.map(skill => skillPrompt(skill)).join('\n\n---\n\n')} `.trim() } \ No newline at end of file diff --git a/agent/src/tools/skill.ts b/agent/src/tools/skill.ts new file mode 100644 index 00000000..f669d074 --- /dev/null +++ b/agent/src/tools/skill.ts @@ -0,0 +1,34 @@ +import { AgentSkill } from '../types' +import { tool } from 'ai' +import { z } from 'zod' + +interface SkillToolParams { + skills: AgentSkill[] + useSkill: (skill: AgentSkill, reason: string) => void +} + +export const getSkillTools = ({ skills, useSkill }: SkillToolParams) => { + const useSkillTool = tool({ + description: 'Use a skill if you think it is relevant to the current task', + inputSchema: z.object({ + skillName: z.string().describe('The name of the skill to use'), + reason: z.string().describe('The reason why you think this skill is relevant to the current task'), + }), + execute: async ({ skillName, reason }) => { + const skill = skills.find((s) => s.name === skillName) + if (!skill) { + return { error: 'Skill not found' } + } + await useSkill(skill, reason) + return { + success: true, + skillName, + reason, + } + }, + }) + + return { + 'use_skill': useSkillTool, + } +} \ No newline at end of file diff --git a/agent/src/types.ts b/agent/src/types.ts index 75820d86..28d068ad 100644 --- a/agent/src/types.ts +++ b/agent/src/types.ts @@ -18,4 +18,10 @@ export interface Schedule { pattern: string maxCalls?: number command: string +} + +export interface AgentSkill { + name: string + description: string + content: string } \ No newline at end of file diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 49d1f011..0ca9b491 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -134,6 +134,7 @@ CREATE INDEX IF NOT EXISTS idx_lifecycle_events_event_type ON lifecycle_events(e CREATE TABLE IF NOT EXISTS history ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), messages JSONB NOT NULL, + skills TEXT[] NOT NULL DEFAULT '{}'::text[], timestamp TIMESTAMPTZ NOT NULL, "user" UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE ); diff --git a/db/queries/history.sql b/db/queries/history.sql index bbb00ec8..c2392ced 100644 --- a/db/queries/history.sql +++ b/db/queries/history.sql @@ -1,21 +1,21 @@ -- name: CreateHistory :one -INSERT INTO history (messages, timestamp, "user") -VALUES ($1, $2, $3) -RETURNING id, messages, timestamp, "user"; +INSERT INTO history (messages, skills, timestamp, "user") +VALUES ($1, $2, $3, $4) +RETURNING id, messages, skills, timestamp, "user"; -- name: ListHistoryByUserSince :many -SELECT id, messages, timestamp, "user" +SELECT id, messages, skills, timestamp, "user" FROM history WHERE "user" = $1 AND timestamp >= $2 ORDER BY timestamp ASC; -- name: GetHistoryByID :one -SELECT id, messages, timestamp, "user" +SELECT id, messages, skills, timestamp, "user" FROM history WHERE id = $1; -- name: ListHistoryByUser :many -SELECT id, messages, timestamp, "user" +SELECT id, messages, skills, timestamp, "user" FROM history WHERE "user" = $1 ORDER BY timestamp DESC diff --git a/docs/docs.go b/docs/docs.go index 94fdb3c3..d7c4f6c6 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -2594,6 +2594,20 @@ const docTemplate = `{ } }, "definitions": { + "chat.AgentSkill": { + "type": "object", + "properties": { + "content": { + "type": "string" + }, + "description": { + "type": "string" + }, + "name": { + "type": "string" + } + } + }, "chat.ChatRequest": { "type": "object", "properties": { @@ -2632,6 +2646,18 @@ const docTemplate = `{ }, "query": { "type": "string" + }, + "skills": { + "type": "array", + "items": { + "$ref": "#/definitions/chat.AgentSkill" + } + }, + "use_skills": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -2649,6 +2675,12 @@ const docTemplate = `{ }, "provider": { "type": "string" + }, + "skills": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -2981,6 +3013,12 @@ const docTemplate = `{ "type": "object", "additionalProperties": true } + }, + "skills": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -3008,6 +3046,12 @@ const docTemplate = `{ "additionalProperties": true } }, + "skills": { + "type": "array", + "items": { + "type": "string" + } + }, "timestamp": { "type": "string" }, diff --git a/docs/swagger.json b/docs/swagger.json index 99940e31..13920970 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -2585,6 +2585,20 @@ } }, "definitions": { + "chat.AgentSkill": { + "type": "object", + "properties": { + "content": { + "type": "string" + }, + "description": { + "type": "string" + }, + "name": { + "type": "string" + } + } + }, "chat.ChatRequest": { "type": "object", "properties": { @@ -2623,6 +2637,18 @@ }, "query": { "type": "string" + }, + "skills": { + "type": "array", + "items": { + "$ref": "#/definitions/chat.AgentSkill" + } + }, + "use_skills": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -2640,6 +2666,12 @@ }, "provider": { "type": "string" + }, + "skills": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -2972,6 +3004,12 @@ "type": "object", "additionalProperties": true } + }, + "skills": { + "type": "array", + "items": { + "type": "string" + } } } }, @@ -2999,6 +3037,12 @@ "additionalProperties": true } }, + "skills": { + "type": "array", + "items": { + "type": "string" + } + }, "timestamp": { "type": "string" }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 4542091d..e6312f55 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,4 +1,13 @@ definitions: + chat.AgentSkill: + properties: + content: + type: string + description: + type: string + name: + type: string + type: object chat.ChatRequest: properties: current_platform: @@ -25,6 +34,14 @@ definitions: type: string query: type: string + skills: + items: + $ref: '#/definitions/chat.AgentSkill' + type: array + use_skills: + items: + type: string + type: array type: object chat.ChatResponse: properties: @@ -36,6 +53,10 @@ definitions: type: string provider: type: string + skills: + items: + type: string + type: array type: object chat.GatewayMessage: additionalProperties: true @@ -251,6 +272,10 @@ definitions: additionalProperties: true type: object type: array + skills: + items: + type: string + type: array type: object history.ListResponse: properties: @@ -268,6 +293,10 @@ definitions: additionalProperties: true type: object type: array + skills: + items: + type: string + type: array timestamp: type: string user_id: diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index e1030607..e51b95a8 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -88,16 +88,22 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err } var messages []GatewayMessage + var historySkills []string if !skipHistory { messages, err = r.loadHistoryMessages(ctx, req.UserID, maxContextLoadTime) if err != nil { return ChatResponse{}, err } + historySkills, err = r.loadHistorySkills(ctx, req.UserID, maxContextLoadTime) + if err != nil { + return ChatResponse{}, err + } } if len(req.Messages) > 0 { messages = append(messages, req.Messages...) } messages = sanitizeGatewayMessages(messages) + useSkills := normalizeSkills(append(historySkills, req.UseSkills...)) payload := agentGatewayRequest{ APIKey: provider.ApiKey, @@ -112,6 +118,8 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err CurrentPlatform: req.CurrentPlatform, Messages: messages, Query: req.Query, + Skills: req.Skills, + UseSkills: useSkills, } payload.Language = language @@ -120,7 +128,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err return ChatResponse{}, err } - if err := r.storeHistory(ctx, req.UserID, req.Query, resp.Messages); err != nil { + if err := r.storeHistory(ctx, req.UserID, req.Query, resp.Messages, resp.Skills); err != nil { return ChatResponse{}, err } if err := r.storeMemory(ctx, req.UserID, req.Query, resp.Messages); err != nil { @@ -129,6 +137,7 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err return ChatResponse{ Messages: resp.Messages, + Skills: resp.Skills, Model: chatModel.ModelID, Provider: provider.ClientType, }, nil @@ -166,6 +175,11 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule if err != nil { return err } + historySkills, err := r.loadHistorySkills(ctx, userID, maxContextLoadTime) + if err != nil { + return err + } + useSkills := normalizeSkills(historySkills) payload := agentGatewayScheduleRequest{ APIKey: provider.ApiKey, @@ -181,13 +195,14 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, userID string, schedule Messages: messages, Query: schedule.Command, Schedule: schedule, + UseSkills: useSkills, } resp, err := r.postSchedule(ctx, payload, token) if err != nil { return err } - if err := r.storeHistory(ctx, userID, schedule.Command, resp.Messages); err != nil { + if err := r.storeHistory(ctx, userID, schedule.Command, resp.Messages, resp.Skills); err != nil { return err } if err := r.storeMemory(ctx, userID, schedule.Command, resp.Messages); err != nil { @@ -238,17 +253,24 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre } var messages []GatewayMessage + var historySkills []string if !skipHistory { messages, err = r.loadHistoryMessages(ctx, req.UserID, maxContextLoadTime) if err != nil { errChan <- err return } + historySkills, err = r.loadHistorySkills(ctx, req.UserID, maxContextLoadTime) + if err != nil { + errChan <- err + return + } } if len(req.Messages) > 0 { messages = append(messages, req.Messages...) } messages = sanitizeGatewayMessages(messages) + useSkills := normalizeSkills(append(historySkills, req.UseSkills...)) payload := agentGatewayRequest{ APIKey: provider.ApiKey, @@ -263,6 +285,8 @@ func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan Stre CurrentPlatform: req.CurrentPlatform, Messages: messages, Query: req.Query, + Skills: req.Skills, + UseSkills: useSkills, } payload.Language = language @@ -288,6 +312,8 @@ type agentGatewayRequest struct { CurrentPlatform string `json:"currentPlatform,omitempty"` Messages []GatewayMessage `json:"messages"` Query string `json:"query"` + Skills []AgentSkill `json:"skills,omitempty"` + UseSkills []string `json:"useSkills,omitempty"` } type agentGatewayScheduleRequest struct { @@ -304,10 +330,13 @@ type agentGatewayScheduleRequest struct { Messages []GatewayMessage `json:"messages"` Query string `json:"query"` Schedule SchedulePayload `json:"schedule"` + Skills []AgentSkill `json:"skills,omitempty"` + UseSkills []string `json:"useSkills,omitempty"` } type agentGatewayResponse struct { Messages []GatewayMessage `json:"messages"` + Skills []string `json:"skills"` } func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest, token string) (agentGatewayResponse, error) { @@ -475,7 +504,36 @@ func (r *Resolver) loadHistoryMessages(ctx context.Context, userID string, maxCo return messages, nil } -func (r *Resolver) storeHistory(ctx context.Context, userID, query string, responseMessages []GatewayMessage) error { +func (r *Resolver) loadHistorySkills(ctx context.Context, userID string, maxContextLoadTime int) ([]string, error) { + if r.queries == nil { + return nil, fmt.Errorf("history queries not configured") + } + pgUserID, err := parseUUID(userID) + if err != nil { + return nil, err + } + from := time.Now().UTC().Add(-time.Duration(normalizeMaxContextLoad(maxContextLoadTime)) * time.Minute) + rows, err := r.queries.ListHistoryByUserSince(ctx, sqlc.ListHistoryByUserSinceParams{ + User: pgUserID, + Timestamp: pgtype.Timestamptz{ + Time: from, + Valid: true, + }, + }) + if err != nil { + return nil, err + } + combined := make([]string, 0, len(rows)) + for _, row := range rows { + if len(row.Skills) == 0 { + continue + } + combined = append(combined, row.Skills...) + } + return normalizeSkills(combined), nil +} + +func (r *Resolver) storeHistory(ctx context.Context, userID, query string, responseMessages []GatewayMessage, skills []string) error { if r.queries == nil { return fmt.Errorf("history queries not configured") } @@ -496,8 +554,10 @@ func (r *Resolver) storeHistory(ctx context.Context, userID, query string, respo if err := r.ensureUserExists(ctx, pgUserID); err != nil { return err } + normalizedSkills := normalizeSkills(skills) _, err = r.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{ Messages: payload, + Skills: normalizedSkills, Timestamp: pgtype.Timestamptz{ Time: time.Now().UTC(), Valid: true, @@ -582,7 +642,7 @@ func (r *Resolver) tryStoreFromStreamPayload(ctx context.Context, userID, query, // Case 1: event: done + data: {messages: [...]} if eventType == "done" { if parsed, ok := parseGatewayResponse([]byte(data)); ok { - return r.storeRound(ctx, userID, query, parsed.Messages) + return r.storeRound(ctx, userID, query, parsed.Messages, parsed.Skills) } } @@ -594,14 +654,14 @@ func (r *Resolver) tryStoreFromStreamPayload(ctx context.Context, userID, query, if err := json.Unmarshal([]byte(data), &envelope); err == nil { if envelope.Type == "done" && len(envelope.Data) > 0 { if parsed, ok := parseGatewayResponse(envelope.Data); ok { - return r.storeRound(ctx, userID, query, parsed.Messages) + return r.storeRound(ctx, userID, query, parsed.Messages, parsed.Skills) } } } // Case 3: data: {messages:[...]} without event if parsed, ok := parseGatewayResponse([]byte(data)); ok { - return r.storeRound(ctx, userID, query, parsed.Messages) + return r.storeRound(ctx, userID, query, parsed.Messages, parsed.Skills) } return false, nil } @@ -617,8 +677,8 @@ func parseGatewayResponse(payload []byte) (agentGatewayResponse, bool) { return parsed, true } -func (r *Resolver) storeRound(ctx context.Context, userID, query string, messages []GatewayMessage) (bool, error) { - if err := r.storeHistory(ctx, userID, query, messages); err != nil { +func (r *Resolver) storeRound(ctx context.Context, userID, query string, messages []GatewayMessage, skills []string) (bool, error) { + if err := r.storeHistory(ctx, userID, query, messages, skills); err != nil { return true, err } if err := r.storeMemory(ctx, userID, query, messages); err != nil { @@ -627,6 +687,23 @@ func (r *Resolver) storeRound(ctx context.Context, userID, query string, message return true, nil } +func normalizeSkills(skills []string) []string { + seen := map[string]struct{}{} + normalized := make([]string, 0, len(skills)) + for _, skill := range skills { + trimmed := strings.TrimSpace(skill) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + normalized = append(normalized, trimmed) + } + return normalized +} + func gatewayMessageToMemory(msg GatewayMessage) (string, string) { role := "assistant" if raw, ok := msg["role"].(string); ok && strings.TrimSpace(raw) != "" { diff --git a/internal/chat/types.go b/internal/chat/types.go index 49f62dbc..117ad40e 100644 --- a/internal/chat/types.go +++ b/internal/chat/types.go @@ -9,6 +9,12 @@ type Message struct { type GatewayMessage map[string]interface{} +type AgentSkill struct { + Name string `json:"name"` + Description string `json:"description"` + Content string `json:"content"` +} + type ChatRequest struct { UserID string `json:"-"` Token string `json:"-"` @@ -22,10 +28,13 @@ type ChatRequest struct { Platforms []string `json:"platforms,omitempty"` CurrentPlatform string `json:"current_platform,omitempty"` Messages []GatewayMessage `json:"messages,omitempty"` + Skills []AgentSkill `json:"skills,omitempty"` + UseSkills []string `json:"use_skills,omitempty"` } type ChatResponse struct { Messages []GatewayMessage `json:"messages"` + Skills []string `json:"skills,omitempty"` Model string `json:"model,omitempty"` Provider string `json:"provider,omitempty"` } diff --git a/internal/db/sqlc/history.sql.go b/internal/db/sqlc/history.sql.go index 00f6f50b..83de32ba 100644 --- a/internal/db/sqlc/history.sql.go +++ b/internal/db/sqlc/history.sql.go @@ -12,23 +12,30 @@ import ( ) const createHistory = `-- name: CreateHistory :one -INSERT INTO history (messages, timestamp, "user") -VALUES ($1, $2, $3) -RETURNING id, messages, timestamp, "user" +INSERT INTO history (messages, skills, timestamp, "user") +VALUES ($1, $2, $3, $4) +RETURNING id, messages, skills, timestamp, "user" ` type CreateHistoryParams struct { Messages []byte `json:"messages"` + Skills []string `json:"skills"` Timestamp pgtype.Timestamptz `json:"timestamp"` User pgtype.UUID `json:"user"` } func (q *Queries) CreateHistory(ctx context.Context, arg CreateHistoryParams) (History, error) { - row := q.db.QueryRow(ctx, createHistory, arg.Messages, arg.Timestamp, arg.User) + row := q.db.QueryRow(ctx, createHistory, + arg.Messages, + arg.Skills, + arg.Timestamp, + arg.User, + ) var i History err := row.Scan( &i.ID, &i.Messages, + &i.Skills, &i.Timestamp, &i.User, ) @@ -56,7 +63,7 @@ func (q *Queries) DeleteHistoryByUser(ctx context.Context, user pgtype.UUID) err } const getHistoryByID = `-- name: GetHistoryByID :one -SELECT id, messages, timestamp, "user" +SELECT id, messages, skills, timestamp, "user" FROM history WHERE id = $1 ` @@ -67,6 +74,7 @@ func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History, err := row.Scan( &i.ID, &i.Messages, + &i.Skills, &i.Timestamp, &i.User, ) @@ -74,7 +82,7 @@ func (q *Queries) GetHistoryByID(ctx context.Context, id pgtype.UUID) (History, } const listHistoryByUser = `-- name: ListHistoryByUser :many -SELECT id, messages, timestamp, "user" +SELECT id, messages, skills, timestamp, "user" FROM history WHERE "user" = $1 ORDER BY timestamp DESC @@ -98,6 +106,7 @@ func (q *Queries) ListHistoryByUser(ctx context.Context, arg ListHistoryByUserPa if err := rows.Scan( &i.ID, &i.Messages, + &i.Skills, &i.Timestamp, &i.User, ); err != nil { @@ -112,7 +121,7 @@ func (q *Queries) ListHistoryByUser(ctx context.Context, arg ListHistoryByUserPa } const listHistoryByUserSince = `-- name: ListHistoryByUserSince :many -SELECT id, messages, timestamp, "user" +SELECT id, messages, skills, timestamp, "user" FROM history WHERE "user" = $1 AND timestamp >= $2 ORDER BY timestamp ASC @@ -135,6 +144,7 @@ func (q *Queries) ListHistoryByUserSince(ctx context.Context, arg ListHistoryByU if err := rows.Scan( &i.ID, &i.Messages, + &i.Skills, &i.Timestamp, &i.User, ); err != nil { diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 42146697..5271b577 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -36,6 +36,7 @@ type ContainerVersion struct { type History struct { ID pgtype.UUID `json:"id"` Messages []byte `json:"messages"` + Skills []string `json:"skills"` Timestamp pgtype.Timestamptz `json:"timestamp"` User pgtype.UUID `json:"user"` } diff --git a/internal/history/service.go b/internal/history/service.go index 0b860358..df99b714 100644 --- a/internal/history/service.go +++ b/internal/history/service.go @@ -44,6 +44,7 @@ func (s *Service) Create(ctx context.Context, userID string, req CreateRequest) } row, err := s.queries.CreateHistory(ctx, sqlc.CreateHistoryParams{ Messages: payload, + Skills: normalizeSkills(req.Skills), Timestamp: pgtype.Timestamptz{ Time: time.Now().UTC(), Valid: true, @@ -122,6 +123,7 @@ func toRecord(row sqlc.History) (Record, error) { } record := Record{ Messages: messages, + Skills: normalizeSkills(row.Skills), } if row.Timestamp.Valid { record.Timestamp = row.Timestamp.Time @@ -141,6 +143,23 @@ func toRecord(row sqlc.History) (Record, error) { return record, nil } +func normalizeSkills(skills []string) []string { + seen := map[string]struct{}{} + normalized := make([]string, 0, len(skills)) + for _, skill := range skills { + trimmed := strings.TrimSpace(skill) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + normalized = append(normalized, trimmed) + } + return normalized +} + func parseUUID(id string) (pgtype.UUID, error) { parsed, err := uuid.Parse(strings.TrimSpace(id)) if err != nil { diff --git a/internal/history/types.go b/internal/history/types.go index 72165388..bc25f2c4 100644 --- a/internal/history/types.go +++ b/internal/history/types.go @@ -5,12 +5,14 @@ import "time" type Record struct { ID string `json:"id"` Messages []map[string]interface{} `json:"messages"` + Skills []string `json:"skills"` Timestamp time.Time `json:"timestamp"` UserID string `json:"user_id"` } type CreateRequest struct { Messages []map[string]interface{} `json:"messages"` + Skills []string `json:"skills,omitempty"` } type ListResponse struct {