From a246b79a4fd81fe77ce8a14b1f32ab992352da07 Mon Sep 17 00:00:00 2001 From: BBQ Date: Fri, 6 Feb 2026 20:22:37 +0800 Subject: [PATCH] refactor: restructure channel gateway and chat module architecture - Refactor channel adapters (feishu, telegram, local) with enhanced descriptor and config - Restructure channel manager, service, types, and outbound messaging - Simplify chat module by removing normalize.go and chat.go, consolidating into resolver and types - Update router channel handlers and tests - Sync swagger documentation --- cmd/agent/main.go | 1 + db/queries/containers.sql | 18 + docs/docs.go | 238 ++-- docs/swagger.json | 238 ++-- docs/swagger.yaml | 238 ++-- internal/bots/service.go | 18 +- internal/bots/types.go | 10 +- internal/channel/adapter.go | 21 +- internal/channel/adapters/common/logging.go | 2 + internal/channel/adapters/feishu/config.go | 10 +- .../channel/adapters/feishu/descriptor.go | 2 + internal/channel/adapters/feishu/feishu.go | 91 +- internal/channel/adapters/local/cli.go | 4 + internal/channel/adapters/local/descriptor.go | 3 + internal/channel/adapters/local/web.go | 4 + internal/channel/adapters/telegram/config.go | 7 + .../channel/adapters/telegram/descriptor.go | 2 + .../channel/adapters/telegram/telegram.go | 63 +- internal/channel/capabilities.go | 4 +- internal/channel/cli_hub.go | 17 +- internal/channel/config.go | 20 +- internal/channel/directory.go | 4 + internal/channel/manager.go | 130 ++- internal/channel/outbound.go | 8 + internal/channel/processor.go | 2 +- internal/channel/registry.go | 11 + internal/channel/schema.go | 5 +- internal/channel/service.go | 50 +- internal/channel/target.go | 4 + internal/channel/types.go | 39 +- internal/chat/assistant_output.go | 54 +- internal/chat/chat.go | 1 - internal/chat/normalize.go | 407 ------- internal/chat/resolver.go | 1033 ++++++----------- internal/chat/schedule_gateway.go | 13 +- internal/chat/types.go | 161 ++- internal/db/sqlc/containers.sql.go | 47 + internal/handlers/containerd.go | 315 +++-- internal/handlers/fs.go | 7 +- internal/router/channel.go | 80 +- internal/router/channel_test.go | 44 +- mise.toml | 2 +- 42 files changed, 1683 insertions(+), 1745 deletions(-) delete mode 100644 internal/chat/chat.go delete mode 100644 internal/chat/normalize.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index adaef571..d740ddfd 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -99,6 +99,7 @@ func main() { usersService := users.NewService(logger.L, queries) containerdHandler := handlers.NewContainerdHandler(logger.L, service, cfg.MCP, cfg.Containerd.Namespace, botService, usersService, queries) + botService.SetContainerLifecycle(containerdHandler) if err := ensureAdminUser(ctx, logger.L, queries, cfg); err != nil { logger.Error("ensure admin user", slog.Any("error", err)) diff --git a/db/queries/containers.sql b/db/queries/containers.sql index 8d446b9f..3e81c964 100644 --- a/db/queries/containers.sql +++ b/db/queries/containers.sql @@ -34,3 +34,21 @@ SELECT * FROM containers WHERE container_id = sqlc.arg(container_id); -- name: GetContainerByBotID :one SELECT * FROM containers WHERE bot_id = sqlc.arg(bot_id) ORDER BY updated_at DESC LIMIT 1; + +-- name: DeleteContainerByBotID :exec +DELETE FROM containers WHERE bot_id = sqlc.arg(bot_id); + +-- name: UpdateContainerStatus :exec +UPDATE containers +SET status = sqlc.arg(status), updated_at = now() +WHERE bot_id = sqlc.arg(bot_id); + +-- name: UpdateContainerStarted :exec +UPDATE containers +SET status = 'running', last_started_at = now(), updated_at = now() +WHERE bot_id = sqlc.arg(bot_id); + +-- name: UpdateContainerStopped :exec +UPDATE containers +SET status = 'stopped', last_stopped_at = now(), updated_at = now() +WHERE bot_id = sqlc.arg(bot_id); diff --git a/docs/docs.go b/docs/docs.go index 3ef5efd5..947e88b9 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -241,6 +241,41 @@ const docTemplate = `{ } }, "/bots/{bot_id}/container": { + "get": { + "tags": [ + "containerd" + ], + "summary": "Get container info for bot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.GetContainerResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, "post": { "tags": [ "containerd" @@ -284,14 +319,12 @@ const docTemplate = `{ } } } - } - }, - "/bots/{bot_id}/container/list": { - "get": { + }, + "delete": { "tags": [ "containerd" ], - "summary": "List containers for bot", + "summary": "Delete MCP container for bot", "parameters": [ { "type": "string", @@ -302,10 +335,13 @@ const docTemplate = `{ } ], "responses": { - "200": { - "description": "OK", + "204": { + "description": "No Content" + }, + "404": { + "description": "Not Found", "schema": { - "$ref": "#/definitions/handlers.ListContainersResponse" + "$ref": "#/definitions/handlers.ErrorResponse" } }, "500": { @@ -351,7 +387,7 @@ const docTemplate = `{ "tags": [ "containerd" ], - "summary": "Create container snapshot", + "summary": "Create container snapshot for bot", "parameters": [ { "type": "string", @@ -376,16 +412,28 @@ const docTemplate = `{ "schema": { "$ref": "#/definitions/handlers.CreateSnapshotResponse" } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } } } } }, - "/bots/{bot_id}/container/{id}": { - "delete": { + "/bots/{bot_id}/container/start": { + "post": { "tags": [ "containerd" ], - "summary": "Delete MCP container", + "summary": "Start container task for bot", "parameters": [ { "type": "string", @@ -393,23 +441,50 @@ const docTemplate = `{ "name": "bot_id", "in": "path", "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "object" + } }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/stop": { + "post": { + "tags": [ + "containerd" + ], + "summary": "Stop container task for bot", + "parameters": [ { "type": "string", - "description": "Container ID", - "name": "id", + "description": "Bot ID", + "name": "bot_id", "in": "path", "required": true } ], "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", + "200": { + "description": "OK", "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" + "type": "object" } }, "404": { @@ -4494,7 +4569,7 @@ const docTemplate = `{ "messages": { "type": "array", "items": { - "$ref": "#/definitions/chat.GatewayMessage" + "$ref": "#/definitions/chat.ModelMessage" } }, "model": { @@ -4520,7 +4595,7 @@ const docTemplate = `{ "messages": { "type": "array", "items": { - "$ref": "#/definitions/chat.GatewayMessage" + "$ref": "#/definitions/chat.ModelMessage" } }, "model": { @@ -4537,9 +4612,56 @@ const docTemplate = `{ } } }, - "chat.GatewayMessage": { + "chat.ModelMessage": { "type": "object", - "additionalProperties": {} + "properties": { + "content": { + "type": "array", + "items": { + "type": "integer" + } + }, + "name": { + "type": "string" + }, + "role": { + "type": "string" + }, + "tool_call_id": { + "type": "string" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/definitions/chat.ToolCall" + } + } + } + }, + "chat.ToolCall": { + "type": "object", + "properties": { + "function": { + "$ref": "#/definitions/chat.ToolCallFunction" + }, + "id": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "chat.ToolCallFunction": { + "type": "object", + "properties": { + "arguments": { + "type": "string" + }, + "name": { + "type": "string" + } + } }, "handlers.ChannelMeta": { "type": "object", @@ -4567,41 +4689,9 @@ const docTemplate = `{ } } }, - "handlers.ContainerInfo": { - "type": "object", - "properties": { - "created_at": { - "type": "string" - }, - "id": { - "type": "string" - }, - "image": { - "type": "string" - }, - "labels": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "snapshot_key": { - "type": "string" - }, - "snapshotter": { - "type": "string" - }, - "updated_at": { - "type": "string" - } - } - }, "handlers.CreateContainerRequest": { "type": "object", "properties": { - "container_id": { - "type": "string" - }, "image": { "type": "string" }, @@ -4630,9 +4720,6 @@ const docTemplate = `{ "handlers.CreateSnapshotRequest": { "type": "object", "properties": { - "container_id": { - "type": "string" - }, "snapshot_name": { "type": "string" } @@ -4748,14 +4835,35 @@ const docTemplate = `{ } } }, - "handlers.ListContainersResponse": { + "handlers.GetContainerResponse": { "type": "object", "properties": { - "containers": { - "type": "array", - "items": { - "$ref": "#/definitions/handlers.ContainerInfo" - } + "container_id": { + "type": "string" + }, + "container_path": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "host_path": { + "type": "string" + }, + "image": { + "type": "string" + }, + "namespace": { + "type": "string" + }, + "status": { + "type": "string" + }, + "task_running": { + "type": "boolean" + }, + "updated_at": { + "type": "string" } } }, diff --git a/docs/swagger.json b/docs/swagger.json index 16a8a22a..88734826 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -232,6 +232,41 @@ } }, "/bots/{bot_id}/container": { + "get": { + "tags": [ + "containerd" + ], + "summary": "Get container info for bot", + "parameters": [ + { + "type": "string", + "description": "Bot ID", + "name": "bot_id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/handlers.GetContainerResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, "post": { "tags": [ "containerd" @@ -275,14 +310,12 @@ } } } - } - }, - "/bots/{bot_id}/container/list": { - "get": { + }, + "delete": { "tags": [ "containerd" ], - "summary": "List containers for bot", + "summary": "Delete MCP container for bot", "parameters": [ { "type": "string", @@ -293,10 +326,13 @@ } ], "responses": { - "200": { - "description": "OK", + "204": { + "description": "No Content" + }, + "404": { + "description": "Not Found", "schema": { - "$ref": "#/definitions/handlers.ListContainersResponse" + "$ref": "#/definitions/handlers.ErrorResponse" } }, "500": { @@ -342,7 +378,7 @@ "tags": [ "containerd" ], - "summary": "Create container snapshot", + "summary": "Create container snapshot for bot", "parameters": [ { "type": "string", @@ -367,16 +403,28 @@ "schema": { "$ref": "#/definitions/handlers.CreateSnapshotResponse" } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } } } } }, - "/bots/{bot_id}/container/{id}": { - "delete": { + "/bots/{bot_id}/container/start": { + "post": { "tags": [ "containerd" ], - "summary": "Delete MCP container", + "summary": "Start container task for bot", "parameters": [ { "type": "string", @@ -384,23 +432,50 @@ "name": "bot_id", "in": "path", "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "object" + } }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/bots/{bot_id}/container/stop": { + "post": { + "tags": [ + "containerd" + ], + "summary": "Stop container task for bot", + "parameters": [ { "type": "string", - "description": "Container ID", - "name": "id", + "description": "Bot ID", + "name": "bot_id", "in": "path", "required": true } ], "responses": { - "204": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", + "200": { + "description": "OK", "schema": { - "$ref": "#/definitions/handlers.ErrorResponse" + "type": "object" } }, "404": { @@ -4485,7 +4560,7 @@ "messages": { "type": "array", "items": { - "$ref": "#/definitions/chat.GatewayMessage" + "$ref": "#/definitions/chat.ModelMessage" } }, "model": { @@ -4511,7 +4586,7 @@ "messages": { "type": "array", "items": { - "$ref": "#/definitions/chat.GatewayMessage" + "$ref": "#/definitions/chat.ModelMessage" } }, "model": { @@ -4528,9 +4603,56 @@ } } }, - "chat.GatewayMessage": { + "chat.ModelMessage": { "type": "object", - "additionalProperties": {} + "properties": { + "content": { + "type": "array", + "items": { + "type": "integer" + } + }, + "name": { + "type": "string" + }, + "role": { + "type": "string" + }, + "tool_call_id": { + "type": "string" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/definitions/chat.ToolCall" + } + } + } + }, + "chat.ToolCall": { + "type": "object", + "properties": { + "function": { + "$ref": "#/definitions/chat.ToolCallFunction" + }, + "id": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "chat.ToolCallFunction": { + "type": "object", + "properties": { + "arguments": { + "type": "string" + }, + "name": { + "type": "string" + } + } }, "handlers.ChannelMeta": { "type": "object", @@ -4558,41 +4680,9 @@ } } }, - "handlers.ContainerInfo": { - "type": "object", - "properties": { - "created_at": { - "type": "string" - }, - "id": { - "type": "string" - }, - "image": { - "type": "string" - }, - "labels": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "snapshot_key": { - "type": "string" - }, - "snapshotter": { - "type": "string" - }, - "updated_at": { - "type": "string" - } - } - }, "handlers.CreateContainerRequest": { "type": "object", "properties": { - "container_id": { - "type": "string" - }, "image": { "type": "string" }, @@ -4621,9 +4711,6 @@ "handlers.CreateSnapshotRequest": { "type": "object", "properties": { - "container_id": { - "type": "string" - }, "snapshot_name": { "type": "string" } @@ -4739,14 +4826,35 @@ } } }, - "handlers.ListContainersResponse": { + "handlers.GetContainerResponse": { "type": "object", "properties": { - "containers": { - "type": "array", - "items": { - "$ref": "#/definitions/handlers.ContainerInfo" - } + "container_id": { + "type": "string" + }, + "container_path": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "host_path": { + "type": "string" + }, + "image": { + "type": "string" + }, + "namespace": { + "type": "string" + }, + "status": { + "type": "string" + }, + "task_running": { + "type": "boolean" + }, + "updated_at": { + "type": "string" } } }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index e5e35f18..5eddb616 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -428,7 +428,7 @@ definitions: type: integer messages: items: - $ref: '#/definitions/chat.GatewayMessage' + $ref: '#/definitions/chat.ModelMessage' type: array model: type: string @@ -445,7 +445,7 @@ definitions: properties: messages: items: - $ref: '#/definitions/chat.GatewayMessage' + $ref: '#/definitions/chat.ModelMessage' type: array model: type: string @@ -456,8 +456,38 @@ definitions: type: string type: array type: object - chat.GatewayMessage: - additionalProperties: {} + chat.ModelMessage: + properties: + content: + items: + type: integer + type: array + name: + type: string + role: + type: string + tool_call_id: + type: string + tool_calls: + items: + $ref: '#/definitions/chat.ToolCall' + type: array + type: object + chat.ToolCall: + properties: + function: + $ref: '#/definitions/chat.ToolCallFunction' + id: + type: string + type: + type: string + type: object + chat.ToolCallFunction: + properties: + arguments: + type: string + name: + type: string type: object handlers.ChannelMeta: properties: @@ -476,29 +506,8 @@ definitions: user_config_schema: $ref: '#/definitions/channel.ConfigSchema' type: object - handlers.ContainerInfo: - properties: - created_at: - type: string - id: - type: string - image: - type: string - labels: - additionalProperties: - type: string - type: object - snapshot_key: - type: string - snapshotter: - type: string - updated_at: - type: string - type: object handlers.CreateContainerRequest: properties: - container_id: - type: string image: type: string snapshotter: @@ -517,8 +526,6 @@ definitions: type: object handlers.CreateSnapshotRequest: properties: - container_id: - type: string snapshot_name: type: string type: object @@ -593,12 +600,26 @@ definitions: message: type: string type: object - handlers.ListContainersResponse: + handlers.GetContainerResponse: properties: - containers: - items: - $ref: '#/definitions/handlers.ContainerInfo' - type: array + container_id: + type: string + container_path: + type: string + created_at: + type: string + host_path: + type: string + image: + type: string + namespace: + type: string + status: + type: string + task_running: + type: boolean + updated_at: + type: string type: object handlers.ListSnapshotsResponse: properties: @@ -1458,6 +1479,50 @@ paths: tags: - chat /bots/{bot_id}/container: + delete: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + responses: + "204": + description: No Content + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Delete MCP container for bot + tags: + - containerd + get: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/handlers.GetContainerResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get container info for bot + tags: + - containerd post: parameters: - description: Bot ID @@ -1487,57 +1552,6 @@ paths: summary: Create and start MCP container for bot tags: - containerd - /bots/{bot_id}/container/{id}: - delete: - parameters: - - description: Bot ID - in: path - name: bot_id - required: true - type: string - - description: Container ID - in: path - name: id - required: true - type: string - responses: - "204": - description: No Content - "400": - description: Bad Request - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "404": - description: Not Found - schema: - $ref: '#/definitions/handlers.ErrorResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: Delete MCP container - tags: - - containerd - /bots/{bot_id}/container/list: - get: - parameters: - - description: Bot ID - in: path - name: bot_id - required: true - type: string - responses: - "200": - description: OK - schema: - $ref: '#/definitions/handlers.ListContainersResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.ErrorResponse' - summary: List containers for bot - tags: - - containerd /bots/{bot_id}/container/snapshots: get: parameters: @@ -1576,7 +1590,63 @@ paths: description: OK schema: $ref: '#/definitions/handlers.CreateSnapshotResponse' - summary: Create container snapshot + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Create container snapshot for bot + tags: + - containerd + /bots/{bot_id}/container/start: + post: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + responses: + "200": + description: OK + schema: + type: object + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Start container task for bot + tags: + - containerd + /bots/{bot_id}/container/stop: + post: + parameters: + - description: Bot ID + in: path + name: bot_id + required: true + type: string + responses: + "200": + description: OK + schema: + type: object + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Stop container task for bot tags: - containerd /bots/{bot_id}/history: diff --git a/internal/bots/service.go b/internal/bots/service.go index d5526aaa..89fed6cd 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -17,8 +17,9 @@ import ( ) type Service struct { - queries *sqlc.Queries - logger *slog.Logger + queries *sqlc.Queries + logger *slog.Logger + containerLifecycle ContainerLifecycle } var ( @@ -40,6 +41,11 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } +// SetContainerLifecycle registers a container lifecycle handler for bot operations. +func (s *Service) SetContainerLifecycle(lc ContainerLifecycle) { + s.containerLifecycle = lc +} + func (s *Service) AuthorizeAccess(ctx context.Context, actorID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { if s.queries == nil { return Bot{}, fmt.Errorf("bot queries not configured") @@ -279,6 +285,14 @@ func (s *Service) Delete(ctx context.Context, botID string) error { if _, err := s.queries.GetBotByID(ctx, botUUID); err != nil { return err } + if s.containerLifecycle != nil { + if err := s.containerLifecycle.CleanupBotContainer(ctx, botID); err != nil { + s.logger.Error("failed to cleanup bot container", + slog.String("bot_id", botID), + slog.Any("error", err), + ) + } + } return s.queries.DeleteBotByID(ctx, botUUID) } diff --git a/internal/bots/types.go b/internal/bots/types.go index 288e7f64..8858c2e6 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -1,6 +1,9 @@ package bots -import "time" +import ( + "context" + "time" +) type Bot struct { ID string `json:"id"` @@ -53,6 +56,11 @@ type ListMembersResponse struct { Items []BotMember `json:"items"` } +// ContainerLifecycle handles container lifecycle events bound to bot operations. +type ContainerLifecycle interface { + CleanupBotContainer(ctx context.Context, botID string) error +} + const ( BotTypePersonal = "personal" BotTypePublic = "public" diff --git a/internal/channel/adapter.go b/internal/channel/adapter.go index 4d566b42..0cb9c0df 100644 --- a/internal/channel/adapter.go +++ b/internal/channel/adapter.go @@ -6,26 +6,33 @@ import ( "sync/atomic" ) +// ErrStopNotSupported is returned when a connection does not support graceful shutdown. var ErrStopNotSupported = errors.New("channel connection stop not supported") +// InboundHandler is a callback invoked when a message arrives from a channel. type InboundHandler func(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error +// ReplySender sends an outbound reply within the scope of a single inbound message. type ReplySender interface { Send(ctx context.Context, msg OutboundMessage) error } +// Adapter is the base interface every channel adapter must implement. type Adapter interface { Type() ChannelType } +// Sender is an adapter capable of sending outbound messages. type Sender interface { Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error } +// Receiver is an adapter capable of establishing a long-lived connection to receive messages. type Receiver interface { Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) } +// Connection represents an active, long-lived link to a channel platform. type Connection interface { ConfigID() string BotID() string @@ -34,6 +41,7 @@ type Connection interface { Running() bool } +// BaseConnection is a default Connection implementation backed by a stop function. type BaseConnection struct { configID string botID string @@ -42,6 +50,7 @@ type BaseConnection struct { running atomic.Bool } +// NewConnection creates a BaseConnection for the given config and stop function. func NewConnection(cfg ChannelConfig, stop func(ctx context.Context) error) *BaseConnection { conn := &BaseConnection{ configID: cfg.ID, @@ -53,29 +62,31 @@ func NewConnection(cfg ChannelConfig, stop func(ctx context.Context) error) *Bas return conn } +// ConfigID returns the channel configuration identifier. func (c *BaseConnection) ConfigID() string { return c.configID } +// BotID returns the bot identifier that owns this connection. func (c *BaseConnection) BotID() string { return c.botID } +// ChannelType returns the type of channel this connection serves. func (c *BaseConnection) ChannelType() ChannelType { return c.channelType } +// Stop gracefully shuts down the connection. func (c *BaseConnection) Stop(ctx context.Context) error { if c.stop == nil { return ErrStopNotSupported } - err := c.stop(ctx) - if err == nil { - c.running.Store(false) - } - return err + c.running.Store(false) + return c.stop(ctx) } +// Running reports whether the connection is still active. func (c *BaseConnection) Running() bool { return c.running.Load() } diff --git a/internal/channel/adapters/common/logging.go b/internal/channel/adapters/common/logging.go index 8c48f6a2..bc680421 100644 --- a/internal/channel/adapters/common/logging.go +++ b/internal/channel/adapters/common/logging.go @@ -1,7 +1,9 @@ +// Package common provides shared utilities for channel adapters. package common import "strings" +// SummarizeText returns a truncated preview of the text, limited to 120 characters. func SummarizeText(text string) string { value := strings.TrimSpace(text) if value == "" { diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go index da302857..fcd9dc1e 100644 --- a/internal/channel/adapters/feishu/config.go +++ b/internal/channel/adapters/feishu/config.go @@ -7,6 +7,7 @@ import ( "github.com/memohai/memoh/internal/channel" ) +// Config holds the Feishu app credentials extracted from a channel configuration. type Config struct { AppID string AppSecret string @@ -14,11 +15,13 @@ type Config struct { VerificationToken string } +// UserConfig holds the identifiers used to target a Feishu user. type UserConfig struct { OpenID string UserID string } +// NormalizeConfig validates and normalizes a Feishu channel configuration map. func NormalizeConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseConfig(raw) if err != nil { @@ -37,6 +40,7 @@ func NormalizeConfig(raw map[string]any) (map[string]any, error) { return result, nil } +// NormalizeUserConfig validates and normalizes a Feishu user-binding configuration map. func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseUserConfig(raw) if err != nil { @@ -52,6 +56,7 @@ func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return result, nil } +// ResolveTarget derives a Feishu delivery target from a user-binding configuration. func ResolveTarget(raw map[string]any) (string, error) { cfg, err := parseUserConfig(raw) if err != nil { @@ -66,6 +71,7 @@ func ResolveTarget(raw map[string]any) (string, error) { return "", fmt.Errorf("feishu binding is incomplete") } +// MatchBinding reports whether a Feishu user binding matches the given criteria. func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { cfg, err := parseUserConfig(raw) if err != nil { @@ -85,6 +91,7 @@ func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { return false } +// BuildUserConfig constructs a Feishu user-binding config from an Identity. func BuildUserConfig(identity channel.Identity) map[string]any { result := map[string]any{} if value := strings.TrimSpace(identity.Attribute("open_id")); value != "" { @@ -135,8 +142,5 @@ func normalizeTarget(raw string) string { if strings.HasPrefix(value, "oc_") { return "chat_id:" + value } - if strings.HasPrefix(value, "user_id:") { - return value - } return "open_id:" + value } diff --git a/internal/channel/adapters/feishu/descriptor.go b/internal/channel/adapters/feishu/descriptor.go index fdb9a5e5..b97be3c8 100644 --- a/internal/channel/adapters/feishu/descriptor.go +++ b/internal/channel/adapters/feishu/descriptor.go @@ -1,7 +1,9 @@ +// Package feishu implements the Feishu (Lark) channel adapter. package feishu import "github.com/memohai/memoh/internal/channel" +// Type is the registered ChannelType identifier for Feishu. const Type channel.ChannelType = "feishu" func init() { diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index dddfcb3d..6d9f1d9c 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -20,10 +20,12 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/common" ) +// FeishuAdapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Feishu. type FeishuAdapter struct { logger *slog.Logger } +// NewFeishuAdapter creates a FeishuAdapter with the given logger. func NewFeishuAdapter(log *slog.Logger) *FeishuAdapter { if log == nil { log = slog.Default() @@ -33,10 +35,12 @@ func NewFeishuAdapter(log *slog.Logger) *FeishuAdapter { } } +// Type returns the Feishu channel type. func (a *FeishuAdapter) Type() channel.ChannelType { return Type } +// Connect establishes a WebSocket connection to Feishu and forwards inbound messages to the handler. func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { a.logger.Info("start", slog.String("config_id", cfg.ID)) @@ -101,6 +105,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, return channel.NewConnection(cfg, stop), nil } +// Send delivers an outbound message to Feishu, handling attachments, rich text, and replies. func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { @@ -117,7 +122,6 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) - // 1. 处理附件 if len(msg.Message.Attachments) > 0 { for _, att := range msg.Message.Attachments { if err := a.sendAttachment(ctx, client, receiveID, receiveType, att, msg.Message.Text); err != nil { @@ -127,27 +131,29 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg return nil } - // 2. 处理富文本或普通文本 var msgType string var content string if len(msg.Message.Parts) > 1 { msgType = larkim.MsgTypePost - content, err = a.buildPostContent(msg.Message) + postContent, postErr := a.buildPostContent(msg.Message) + if postErr != nil { + return postErr + } + content = postContent } else { msgType = larkim.MsgTypeText text := strings.TrimSpace(msg.Message.PlainText()) if text == "" { return fmt.Errorf("message is required") } - payload, _ := json.Marshal(map[string]string{"text": text}) + payload, marshalErr := json.Marshal(map[string]string{"text": text}) + if marshalErr != nil { + return fmt.Errorf("failed to marshal text content: %w", marshalErr) + } content = string(payload) } - if err != nil { - return err - } - reqBuilder := larkim.NewCreateMessageReqBodyBuilder(). ReceiveId(receiveID). MsgType(msgType). @@ -159,7 +165,6 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg Body(reqBuilder.Build()). Build() - // 处理回复 if msg.Message.Reply != nil && msg.Message.Reply.MessageID != "" { replyReq := larkim.NewReplyMessageReqBuilder(). MessageId(msg.Message.Reply.MessageID). @@ -228,8 +233,12 @@ func (a *FeishuAdapter) handleResponse(configID string, resp *larkim.CreateMessa } func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, receiveID, receiveType string, att channel.Attachment, text string) error { - // 下载文件 - resp, err := http.Get(att.URL) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, att.URL, nil) + if err != nil { + return fmt.Errorf("failed to build download request: %w", err) + } + httpClient := &http.Client{Timeout: 60 * time.Second} + resp, err := httpClient.Do(httpReq) if err != nil { return fmt.Errorf("failed to download attachment: %w", err) } @@ -243,7 +252,6 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, var contentMap map[string]string if strings.HasPrefix(att.Mime, "image/") || att.Type == channel.AttachmentImage { - // 上传图片 uploadReq := larkim.NewCreateImageReqBuilder(). Body(larkim.NewCreateImageReqBodyBuilder(). ImageType(larkim.ImageTypeMessage). @@ -251,29 +259,46 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, Build()). Build() uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) - if err != nil || !uploadResp.Success() { + if err != nil { return fmt.Errorf("failed to upload image: %w", err) } + if uploadResp == nil || !uploadResp.Success() { + code, msg := 0, "" + if uploadResp != nil { + code, msg = uploadResp.Code, uploadResp.Msg + } + return fmt.Errorf("failed to upload image: %s (code: %d)", msg, code) + } msgType = larkim.MsgTypeImage contentMap = map[string]string{"image_key": *uploadResp.Data.ImageKey} } else { - // 上传文件 + fileType := resolveFeishuFileType(att.Name, att.Mime) uploadReq := larkim.NewCreateFileReqBuilder(). Body(larkim.NewCreateFileReqBodyBuilder(). - FileType(larkim.FileTypePdf). // 默认为 pdf,飞书支持 mp4, doc, xls, ppt, pdf, zip + FileType(fileType). FileName(att.Name). File(resp.Body). Build()). Build() uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) - if err != nil || !uploadResp.Success() { + if err != nil { return fmt.Errorf("failed to upload file: %w", err) } + if uploadResp == nil || !uploadResp.Success() { + code, msg := 0, "" + if uploadResp != nil { + code, msg = uploadResp.Code, uploadResp.Msg + } + return fmt.Errorf("failed to upload file: %s (code: %d)", msg, code) + } msgType = larkim.MsgTypeFile contentMap = map[string]string{"file_key": *uploadResp.Data.FileKey} } - content, _ := json.Marshal(contentMap) + content, err := json.Marshal(contentMap) + if err != nil { + return fmt.Errorf("failed to marshal content: %w", err) + } req := larkim.NewCreateMessageReqBuilder(). ReceiveIdType(receiveType). Body(larkim.NewCreateMessageReqBodyBuilder(). @@ -284,12 +309,32 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, Build()). Build() - _, err = client.Im.V1.Message.Create(ctx, req) - return err + sendResp, err := client.Im.V1.Message.Create(ctx, req) + return a.handleResponse("", sendResp, err) +} + +// resolveFeishuFileType maps MIME type and filename to a Feishu file type constant. +func resolveFeishuFileType(name, mime string) string { + lower := strings.ToLower(mime) + switch { + case strings.Contains(lower, "mp4") || strings.Contains(lower, "video"): + return larkim.FileTypeMp4 + case strings.Contains(lower, "pdf"): + return larkim.FileTypePdf + case strings.Contains(lower, "word") || strings.Contains(lower, "msword") || strings.HasSuffix(strings.ToLower(name), ".doc") || strings.HasSuffix(strings.ToLower(name), ".docx"): + return larkim.FileTypeDoc + case strings.Contains(lower, "excel") || strings.Contains(lower, "spreadsheet") || strings.HasSuffix(strings.ToLower(name), ".xls") || strings.HasSuffix(strings.ToLower(name), ".xlsx"): + return larkim.FileTypeXls + case strings.Contains(lower, "powerpoint") || strings.Contains(lower, "presentation") || strings.HasSuffix(strings.ToLower(name), ".ppt") || strings.HasSuffix(strings.ToLower(name), ".pptx"): + return larkim.FileTypePpt + case strings.Contains(lower, "zip") || strings.Contains(lower, "compressed") || strings.Contains(lower, "archive"): + return "zip" + default: + return larkim.FileTypePdf + } } func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { - // 简单的 Post 构建逻辑 type postContent struct { ZhCn struct { Title string `json:"title"` @@ -298,7 +343,7 @@ func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { } pc := postContent{} - pc.ZhCn.Title = "" // 暂时不设标题 + pc.ZhCn.Title = "" line := []any{} for _, part := range msg.Parts { @@ -326,7 +371,6 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa msg.ID = *message.MessageId } - // 解析内容 var contentMap map[string]any if message.Content != nil { _ = json.Unmarshal([]byte(*message.Content), &contentMap) @@ -342,7 +386,7 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa if key, ok := contentMap["image_key"].(string); ok { msg.Attachments = append(msg.Attachments, channel.Attachment{ Type: channel.AttachmentImage, - URL: key, // 飞书内部 key,上层需注意 + URL: key, }) } case larkim.MsgTypeFile, larkim.MsgTypeAudio: @@ -357,7 +401,6 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1) channel.InboundMessa } } - // 处理回复引用 if message.ParentId != nil && *message.ParentId != "" { msg.Reply = &channel.ReplyRef{ MessageID: *message.ParentId, diff --git a/internal/channel/adapters/local/cli.go b/internal/channel/adapters/local/cli.go index e2675207..42a6bdb7 100644 --- a/internal/channel/adapters/local/cli.go +++ b/internal/channel/adapters/local/cli.go @@ -8,18 +8,22 @@ import ( "github.com/memohai/memoh/internal/channel" ) +// CLIAdapter implements channel.Sender for the local CLI channel. type CLIAdapter struct { hub *channel.SessionHub } +// NewCLIAdapter creates a CLIAdapter backed by the given session hub. func NewCLIAdapter(hub *channel.SessionHub) *CLIAdapter { return &CLIAdapter{hub: hub} } +// Type returns the CLI channel type. func (a *CLIAdapter) Type() channel.ChannelType { return CLIType } +// Send publishes an outbound message to the CLI session hub. func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("cli hub not configured") diff --git a/internal/channel/adapters/local/descriptor.go b/internal/channel/adapters/local/descriptor.go index abf051d5..bfc99c90 100644 --- a/internal/channel/adapters/local/descriptor.go +++ b/internal/channel/adapters/local/descriptor.go @@ -1,3 +1,4 @@ +// Package local implements the CLI and Web channel adapters for local development. package local import ( @@ -7,7 +8,9 @@ import ( ) const ( + // CLIType is the registered ChannelType for the CLI adapter. CLIType channel.ChannelType = "cli" + // WebType is the registered ChannelType for the Web adapter. WebType channel.ChannelType = "web" ) diff --git a/internal/channel/adapters/local/web.go b/internal/channel/adapters/local/web.go index d0b24682..37ab217a 100644 --- a/internal/channel/adapters/local/web.go +++ b/internal/channel/adapters/local/web.go @@ -8,18 +8,22 @@ import ( "github.com/memohai/memoh/internal/channel" ) +// WebAdapter implements channel.Sender for the local Web channel. type WebAdapter struct { hub *channel.SessionHub } +// NewWebAdapter creates a WebAdapter backed by the given session hub. func NewWebAdapter(hub *channel.SessionHub) *WebAdapter { return &WebAdapter{hub: hub} } +// Type returns the Web channel type. func (a *WebAdapter) Type() channel.ChannelType { return WebType } +// Send publishes an outbound message to the Web session hub. func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { return fmt.Errorf("web hub not configured") diff --git a/internal/channel/adapters/telegram/config.go b/internal/channel/adapters/telegram/config.go index dd972838..3bfdc99d 100644 --- a/internal/channel/adapters/telegram/config.go +++ b/internal/channel/adapters/telegram/config.go @@ -7,16 +7,19 @@ import ( "github.com/memohai/memoh/internal/channel" ) +// Config holds the Telegram bot credentials extracted from a channel configuration. type Config struct { BotToken string } +// UserConfig holds the identifiers used to target a Telegram user or group. type UserConfig struct { Username string UserID string ChatID string } +// NormalizeConfig validates and normalizes a Telegram channel configuration map. func NormalizeConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseConfig(raw) if err != nil { @@ -27,6 +30,7 @@ func NormalizeConfig(raw map[string]any) (map[string]any, error) { }, nil } +// NormalizeUserConfig validates and normalizes a Telegram user-binding configuration map. func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { cfg, err := parseUserConfig(raw) if err != nil { @@ -45,6 +49,7 @@ func NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return result, nil } +// ResolveTarget derives a Telegram delivery target from a user-binding configuration. func ResolveTarget(raw map[string]any) (string, error) { cfg, err := parseUserConfig(raw) if err != nil { @@ -66,6 +71,7 @@ func ResolveTarget(raw map[string]any) (string, error) { return "", fmt.Errorf("telegram binding is incomplete") } +// MatchBinding reports whether a Telegram user binding matches the given criteria. func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { cfg, err := parseUserConfig(raw) if err != nil { @@ -88,6 +94,7 @@ func MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { return false } +// BuildUserConfig constructs a Telegram user-binding config from an Identity. func BuildUserConfig(identity channel.Identity) map[string]any { result := map[string]any{} if value := strings.TrimSpace(identity.Attribute("username")); value != "" { diff --git a/internal/channel/adapters/telegram/descriptor.go b/internal/channel/adapters/telegram/descriptor.go index 2d8dd4bb..268bc9d8 100644 --- a/internal/channel/adapters/telegram/descriptor.go +++ b/internal/channel/adapters/telegram/descriptor.go @@ -1,7 +1,9 @@ +// Package telegram implements the Telegram channel adapter. package telegram import "github.com/memohai/memoh/internal/channel" +// Type is the registered ChannelType identifier for Telegram. const Type channel.ChannelType = "telegram" func init() { diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index 9595bd0f..d838b139 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -6,6 +6,7 @@ import ( "log/slog" "strconv" "strings" + "sync" "time" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" @@ -14,23 +15,53 @@ import ( "github.com/memohai/memoh/internal/channel/adapters/common" ) +// TelegramAdapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Telegram. type TelegramAdapter struct { logger *slog.Logger + mu sync.RWMutex + bots map[string]*tgbotapi.BotAPI // keyed by bot token } +// NewTelegramAdapter creates a TelegramAdapter with the given logger. func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { if log == nil { log = slog.Default() } return &TelegramAdapter{ logger: log.With(slog.String("adapter", "telegram")), + bots: make(map[string]*tgbotapi.BotAPI), } } +func (a *TelegramAdapter) getOrCreateBot(token, configID string) (*tgbotapi.BotAPI, error) { + a.mu.RLock() + bot, ok := a.bots[token] + a.mu.RUnlock() + if ok { + return bot, nil + } + a.mu.Lock() + defer a.mu.Unlock() + if bot, ok := a.bots[token]; ok { + return bot, nil + } + bot, err := tgbotapi.NewBotAPI(token) + if err != nil { + if a.logger != nil { + a.logger.Error("create bot failed", slog.String("config_id", configID), slog.Any("error", err)) + } + return nil, err + } + a.bots[token] = bot + return bot, nil +} + +// Type returns the Telegram channel type. func (a *TelegramAdapter) Type() channel.ChannelType { return Type } +// Connect starts long-polling for Telegram updates and forwards messages to the handler. func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { a.logger.Info("start", slog.String("config_id", cfg.ID)) @@ -147,6 +178,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig return channel.NewConnection(cfg, stop), nil } +// Send delivers an outbound message to Telegram, handling text, attachments, and replies. func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { @@ -159,11 +191,8 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m if to == "" { return fmt.Errorf("telegram target is required") } - bot, err := tgbotapi.NewBotAPI(telegramCfg.BotToken) + bot, err := a.getOrCreateBot(telegramCfg.BotToken, cfg.ID) if err != nil { - if a.logger != nil { - a.logger.Error("create bot failed", slog.String("config_id", cfg.ID), slog.Any("error", err)) - } return err } if msg.Message.IsEmpty() { @@ -218,7 +247,7 @@ func resolveTelegramSender(msg *tgbotapi.Message) (string, string, map[string]st } displayName := strings.TrimSpace(msg.From.UserName) if displayName == "" { - displayName = strings.TrimSpace(strings.TrimSpace(msg.From.FirstName + " " + msg.From.LastName)) + displayName = strings.TrimSpace(msg.From.FirstName + " " + msg.From.LastName) } externalID := userID if externalID == "" { @@ -320,22 +349,28 @@ func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Att _, err := bot.Send(photo) return err case channel.AttachmentFile, "": - chatID, err := strconv.ParseInt(target, 10, 64) - if err != nil && !isChannel { - return fmt.Errorf("telegram target must be @username or chat_id") - } - document := tgbotapi.NewDocument(chatID, file) + var document tgbotapi.DocumentConfig if isChannel { - document.ChatID = 0 - document.ChannelUsername = target + document = tgbotapi.DocumentConfig{ + BaseFile: tgbotapi.BaseFile{ + BaseChat: tgbotapi.BaseChat{ChannelUsername: target}, + File: file, + }, + } + } else { + chatID, err := strconv.ParseInt(target, 10, 64) + if err != nil { + return fmt.Errorf("telegram target must be @username or chat_id") + } + document = tgbotapi.NewDocument(chatID, file) } document.Caption = caption document.ParseMode = parseMode if replyTo > 0 { document.ReplyToMessageID = replyTo } - _, err = bot.Send(document) - return err + _, sendErr := bot.Send(document) + return sendErr case channel.AttachmentAudio: audio, err := buildTelegramAudio(target, file) if err != nil { diff --git a/internal/channel/capabilities.go b/internal/channel/capabilities.go index b72a7af2..a1cf379c 100644 --- a/internal/channel/capabilities.go +++ b/internal/channel/capabilities.go @@ -1,7 +1,7 @@ package channel -// ChannelCapabilities 描述通道在功能层面的能力矩阵。 -// 该结构用于上层自适应逻辑,不依赖具体适配器实现。 +// ChannelCapabilities describes the feature matrix of a channel type. +// It is used by the outbound layer to validate message content before delivery. type ChannelCapabilities struct { Text bool `json:"text"` Markdown bool `json:"markdown"` diff --git a/internal/channel/cli_hub.go b/internal/channel/cli_hub.go index d5b47925..4599ccb8 100644 --- a/internal/channel/cli_hub.go +++ b/internal/channel/cli_hub.go @@ -6,17 +6,21 @@ import ( "github.com/google/uuid" ) +// SessionHub is a pub/sub hub that routes outbound messages to CLI/Web session subscribers. type SessionHub struct { mu sync.RWMutex sessions map[string]map[string]chan OutboundMessage } +// NewSessionHub creates an empty SessionHub. func NewSessionHub() *SessionHub { return &SessionHub{ sessions: map[string]map[string]chan OutboundMessage{}, } } +// Subscribe registers a new stream for the given session and returns a stream ID, +// a read-only channel for messages, and a cancel function to unsubscribe. func (h *SessionHub) Subscribe(sessionID string) (string, <-chan OutboundMessage, func()) { streamID := uuid.NewString() ch := make(chan OutboundMessage, 32) @@ -48,17 +52,14 @@ func (h *SessionHub) Subscribe(sessionID string) (string, <-chan OutboundMessage return streamID, ch, cancel } +// Publish delivers a message to all subscribers of the given session. +// Slow receivers are silently dropped. func (h *SessionHub) Publish(sessionID string, msg OutboundMessage) { h.mu.RLock() - streams := h.sessions[sessionID] - h.mu.RUnlock() - if len(streams) == 0 { - return - } - - for _, stream := range streams { + defer h.mu.RUnlock() + for _, ch := range h.sessions[sessionID] { select { - case stream <- msg: + case ch <- msg: default: // Drop if receiver is slow. } diff --git a/internal/channel/config.go b/internal/channel/config.go index 9f8014bc..6067cd39 100644 --- a/internal/channel/config.go +++ b/internal/channel/config.go @@ -3,9 +3,11 @@ package channel import ( "encoding/json" "fmt" - "strings" + "strconv" ) +// NormalizeChannelConfig validates and normalizes a channel configuration map +// using the registered descriptor for the given channel type. func NormalizeChannelConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { if raw == nil { raw = map[string]any{} @@ -20,6 +22,7 @@ func NormalizeChannelConfig(channelType ChannelType, raw map[string]any) (map[st return desc.NormalizeConfig(raw) } +// NormalizeChannelUserConfig validates and normalizes a user-channel binding configuration. func NormalizeChannelUserConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { if raw == nil { raw = map[string]any{} @@ -34,6 +37,7 @@ func NormalizeChannelUserConfig(channelType ChannelType, raw map[string]any) (ma return desc.NormalizeUserConfig(raw) } +// ResolveTargetFromUserConfig derives a delivery target string from a user-channel binding. func ResolveTargetFromUserConfig(channelType ChannelType, config map[string]any) (string, error) { desc, ok := GetChannelDescriptor(channelType) if !ok || desc.ResolveTarget == nil { @@ -42,6 +46,7 @@ func ResolveTargetFromUserConfig(channelType ChannelType, config map[string]any) return desc.ResolveTarget(config) } +// MatchUserBinding reports whether the given binding config matches the criteria. func MatchUserBinding(channelType ChannelType, config map[string]any, criteria BindingCriteria) bool { desc, ok := GetChannelDescriptor(channelType) if !ok || desc.MatchBinding == nil { @@ -50,6 +55,7 @@ func MatchUserBinding(channelType ChannelType, config map[string]any, criteria B return desc.MatchBinding(config, criteria) } +// BuildUserBindingConfig constructs a user-channel binding config from an Identity. func BuildUserBindingConfig(channelType ChannelType, identity Identity) map[string]any { desc, ok := GetChannelDescriptor(channelType) if !ok || desc.BuildUserConfig == nil { @@ -58,6 +64,7 @@ func BuildUserBindingConfig(channelType ChannelType, identity Identity) map[stri return desc.BuildUserConfig(identity) } +// DecodeConfigMap unmarshals a JSON byte slice into a string-keyed map. func DecodeConfigMap(raw []byte) (map[string]any, error) { if len(raw) == 0 { return map[string]any{}, nil @@ -72,17 +79,20 @@ func DecodeConfigMap(raw []byte) (map[string]any, error) { return payload, nil } +// ReadString looks up the first matching key in a map and returns its string representation. +// It tries each key in order and converts non-string values using type-safe formatting. func ReadString(raw map[string]any, keys ...string) string { for _, key := range keys { if value, ok := raw[key]; ok { switch v := value.(type) { case string: return v + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case bool: + return strconv.FormatBool(v) default: - encoded, err := json.Marshal(v) - if err == nil { - return strings.Trim(string(encoded), "\"") - } + return fmt.Sprintf("%v", v) } } } diff --git a/internal/channel/directory.go b/internal/channel/directory.go index dc629dbd..fa82e98e 100644 --- a/internal/channel/directory.go +++ b/internal/channel/directory.go @@ -2,6 +2,7 @@ package channel import "context" +// DirectoryEntryKind classifies a directory entry as a user or a group. type DirectoryEntryKind string const ( @@ -9,6 +10,7 @@ const ( DirectoryEntryGroup DirectoryEntryKind = "group" ) +// DirectoryEntry represents a single user or group discovered through the channel's directory. type DirectoryEntry struct { Kind DirectoryEntryKind `json:"kind"` ID string `json:"id"` @@ -18,12 +20,14 @@ type DirectoryEntry struct { Metadata map[string]any `json:"metadata,omitempty"` } +// DirectoryQuery contains filters for directory listing operations. type DirectoryQuery struct { Query string `json:"query,omitempty"` Limit int `json:"limit,omitempty"` Kind DirectoryEntryKind `json:"kind,omitempty"` } +// ChannelDirectoryAdapter provides contact and group lookup for a channel platform. type ChannelDirectoryAdapter interface { ListPeers(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) ListGroups(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) diff --git a/internal/channel/manager.go b/internal/channel/manager.go index 3f88263d..d46a61f3 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -10,6 +10,7 @@ import ( "time" ) +// ConfigStore abstracts the persistence layer used by the Manager. type ConfigStore interface { ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) @@ -21,9 +22,10 @@ type ConfigStore interface { UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error } -// Middleware 消息处理中间件定义 +// Middleware wraps an InboundHandler to add cross-cutting behavior. type Middleware func(next InboundHandler) InboundHandler +// Manager coordinates channel adapters, connection lifecycle, and message dispatch. type Manager struct { service ConfigStore processor InboundProcessor @@ -49,6 +51,7 @@ type connectionEntry struct { connection Connection } +// NewManager creates a Manager with the given logger, config store, and inbound processor. func NewManager(log *slog.Logger, service ConfigStore, processor InboundProcessor) *Manager { if log == nil { log = slog.Default() @@ -68,11 +71,12 @@ func NewManager(log *slog.Logger, service ConfigStore, processor InboundProcesso } } -// Use 注册中间件 +// Use appends middleware to the inbound processing chain. func (m *Manager) Use(mw ...Middleware) { m.middlewares = append(m.middlewares, mw...) } +// RegisterAdapter adds an adapter and indexes its Sender/Receiver capabilities. func (m *Manager) RegisterAdapter(adapter Adapter) { if adapter == nil { return @@ -91,7 +95,7 @@ func (m *Manager) RegisterAdapter(adapter Adapter) { } } -// AddAdapter 注册适配器并触发一次刷新(便于热插拔)。 +// AddAdapter registers an adapter and triggers an immediate refresh for hot-plug support. func (m *Manager) AddAdapter(ctx context.Context, adapter Adapter) { m.RegisterAdapter(adapter) if ctx != nil { @@ -99,7 +103,7 @@ func (m *Manager) AddAdapter(ctx context.Context, adapter Adapter) { } } -// RemoveAdapter 移除适配器并停止其连接(便于热插拔)。 +// RemoveAdapter unregisters an adapter and stops all its active connections. func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { if ctx == nil { ctx = context.Background() @@ -112,7 +116,9 @@ func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { for id, entry := range m.connections { if entry != nil && entry.config.ChannelType == normalized { if entry.connection != nil { - _ = entry.connection.Stop(ctx) + if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { + m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) + } } delete(m.connections, id) } @@ -126,6 +132,7 @@ func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { m.adapterMu.Unlock() } +// Start begins the periodic config refresh loop and inbound worker pool. func (m *Manager) Start(ctx context.Context) { if m.logger != nil { m.logger.Info("manager start") @@ -150,6 +157,7 @@ func (m *Manager) Start(ctx context.Context) { }() } +// Send delivers an outbound message to the specified channel, resolving target and config automatically. func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelType, req SendRequest) error { if m.service == nil { return fmt.Errorf("channel manager not configured") @@ -210,20 +218,20 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp return nil } +// HandleInbound enqueues an inbound message for asynchronous processing by the worker pool. func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { if m.processor == nil { return fmt.Errorf("inbound processor not configured") } + if ctx == nil { + ctx = context.Background() + } m.startInboundWorkers(ctx) if m.inboundCtx != nil && m.inboundCtx.Err() != nil { return fmt.Errorf("inbound dispatcher stopped") } - taskCtx := ctx - if ctx != nil { - taskCtx = context.WithoutCancel(ctx) - } task := inboundTask{ - ctx: taskCtx, + ctx: context.WithoutCancel(ctx), cfg: cfg, msg: msg, } @@ -282,7 +290,9 @@ func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { if m.logger != nil { m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) } - _ = entry.connection.Stop(ctx) + if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { + m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) + } } delete(m.connections, id) } @@ -295,14 +305,15 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error if receiver == nil { return nil } + m.mu.Lock() entry := m.connections[cfg.ID] - m.mu.Unlock() - + if entry != nil && !entry.config.UpdatedAt.Before(cfg.UpdatedAt) { + m.mu.Unlock() + return nil + } if entry != nil { - if entry.config.UpdatedAt.Equal(cfg.UpdatedAt) { - return nil - } + m.mu.Unlock() if m.logger != nil { m.logger.Info("adapter restart", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } @@ -318,17 +329,17 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error m.mu.Lock() delete(m.connections, cfg.ID) m.mu.Unlock() + } else { + m.mu.Unlock() } + if m.logger != nil { m.logger.Info("adapter start", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) } - - // 包装中间件 handler := m.handleInbound for i := len(m.middlewares) - 1; i >= 0; i-- { handler = m.middlewares[i](handler) } - conn, err := receiver.Connect(ctx, cfg, handler) if err != nil { return err @@ -350,7 +361,9 @@ func (m *Manager) stopAll(ctx context.Context) { if m.logger != nil { m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) } - _ = entry.connection.Stop(ctx) + if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { + m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) + } } delete(m.connections, id) } @@ -370,6 +383,7 @@ func (m *Manager) handleInbound(ctx context.Context, cfg ChannelConfig, msg Inbo return nil } +// Stop terminates the connection identified by the given config ID. func (m *Manager) Stop(ctx context.Context, configID string) error { configID = strings.TrimSpace(configID) if configID == "" { @@ -384,6 +398,7 @@ func (m *Manager) Stop(ctx context.Context, configID string) error { return entry.connection.Stop(ctx) } +// StopByBot terminates all connections belonging to the given bot. func (m *Manager) StopByBot(ctx context.Context, botID string) error { botID = strings.TrimSpace(botID) if botID == "" { @@ -402,6 +417,7 @@ func (m *Manager) StopByBot(ctx context.Context, botID string) error { return nil } +// Shutdown cancels the inbound worker pool and stops all active connections. func (m *Manager) Shutdown(ctx context.Context) error { if m.inboundCancel != nil { m.inboundCancel() @@ -474,8 +490,9 @@ func (m *Manager) resolveOutboundPolicy(channelType ChannelType) OutboundPolicy return NormalizeOutboundPolicy(policy) } +// buildOutboundMessages splits an outbound message into multiple messages based on the policy. +// The caller must pass an already-normalized policy. func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]OutboundMessage, error) { - policy = NormalizeOutboundPolicy(policy) if msg.Message.IsEmpty() { return nil, fmt.Errorf("message is required") } @@ -551,6 +568,46 @@ func normalizeOutboundMessage(msg Message) Message { return msg } +func validateMessageCapabilities(channelType ChannelType, msg Message) error { + caps, ok := GetChannelCapabilities(channelType) + if !ok { + return nil + } + switch msg.Format { + case MessageFormatPlain: + if !caps.Text { + return fmt.Errorf("channel does not support plain text") + } + case MessageFormatMarkdown: + if !caps.Markdown && !caps.RichText { + return fmt.Errorf("channel does not support markdown") + } + case MessageFormatRich: + if !caps.RichText { + return fmt.Errorf("channel does not support rich text") + } + } + if len(msg.Parts) > 0 && !caps.RichText { + return fmt.Errorf("channel does not support rich text") + } + if len(msg.Attachments) > 0 && !caps.Attachments { + return fmt.Errorf("channel does not support attachments") + } + if len(msg.Attachments) > 0 && requiresMedia(msg.Attachments) && !caps.Media { + return fmt.Errorf("channel does not support media") + } + if len(msg.Actions) > 0 && !caps.Buttons { + return fmt.Errorf("channel does not support actions") + } + if msg.Thread != nil && !caps.Threads { + return fmt.Errorf("channel does not support threads") + } + if msg.Reply != nil && !caps.Reply { + return fmt.Errorf("channel does not support reply") + } + return nil +} + func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg ChannelConfig, msg OutboundMessage, policy OutboundPolicy) error { if sender == nil { return fmt.Errorf("unsupported channel type: %s", cfg.ChannelType) @@ -562,36 +619,9 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel if msg.Message.IsEmpty() { return fmt.Errorf("message is required") } - if caps, ok := GetChannelCapabilities(cfg.ChannelType); ok { - if msg.Message.Format == MessageFormatPlain && !caps.Text { - return fmt.Errorf("channel does not support plain text") - } - if msg.Message.Format == MessageFormatMarkdown && !(caps.Markdown || caps.RichText) { - return fmt.Errorf("channel does not support markdown") - } - if msg.Message.Format == MessageFormatRich && !caps.RichText { - return fmt.Errorf("channel does not support rich text") - } - if len(msg.Message.Parts) > 0 && !caps.RichText { - return fmt.Errorf("channel does not support rich text") - } - if len(msg.Message.Attachments) > 0 && !caps.Attachments { - return fmt.Errorf("channel does not support attachments") - } - if len(msg.Message.Attachments) > 0 && requiresMedia(msg.Message.Attachments) && !caps.Media { - return fmt.Errorf("channel does not support media") - } - if len(msg.Message.Actions) > 0 && !caps.Buttons { - return fmt.Errorf("channel does not support actions") - } - if msg.Message.Thread != nil && !caps.Threads { - return fmt.Errorf("channel does not support threads") - } - if msg.Message.Reply != nil && !caps.Reply { - return fmt.Errorf("channel does not support reply") - } + if err := validateMessageCapabilities(cfg.ChannelType, msg.Message); err != nil { + return err } - policy = NormalizeOutboundPolicy(policy) var lastErr error for i := 0; i < policy.RetryMax; i++ { err := sender.Send(ctx, cfg, OutboundMessage{Target: target, Message: msg.Message}) diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go index 79653593..b16644dd 100644 --- a/internal/channel/outbound.go +++ b/internal/channel/outbound.go @@ -2,6 +2,7 @@ package channel import "strings" +// ChunkerMode selects the text chunking strategy. type ChunkerMode string const ( @@ -9,6 +10,7 @@ const ( ChunkerModeMarkdown ChunkerMode = "markdown" ) +// OutboundOrder controls the delivery order of text and media messages. type OutboundOrder string const ( @@ -16,8 +18,10 @@ const ( OutboundOrderTextFirst OutboundOrder = "text_first" ) +// Chunker splits text into pieces that respect a character limit. type Chunker func(text string, limit int) []string +// OutboundPolicy configures how outbound messages are chunked, ordered, and retried. type OutboundPolicy struct { TextChunkLimit int `json:"text_chunk_limit,omitempty"` ChunkerMode ChunkerMode `json:"chunker_mode,omitempty"` @@ -27,6 +31,7 @@ type OutboundPolicy struct { RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` } +// NormalizeOutboundPolicy fills zero-value fields with sensible defaults. func NormalizeOutboundPolicy(policy OutboundPolicy) OutboundPolicy { if policy.TextChunkLimit <= 0 { policy.TextChunkLimit = 2000 @@ -49,6 +54,7 @@ func NormalizeOutboundPolicy(policy OutboundPolicy) OutboundPolicy { return policy } +// DefaultChunker returns the built-in Chunker for the given mode. func DefaultChunker(mode ChunkerMode) Chunker { switch mode { case ChunkerModeMarkdown: @@ -58,6 +64,7 @@ func DefaultChunker(mode ChunkerMode) Chunker { } } +// ChunkText splits text at newline boundaries, respecting the rune limit. func ChunkText(text string, limit int) []string { trimmed := strings.TrimSpace(text) if trimmed == "" { @@ -99,6 +106,7 @@ func ChunkText(text string, limit int) []string { return chunks } +// ChunkMarkdownText splits text at paragraph boundaries (double newlines), respecting the rune limit. func ChunkMarkdownText(text string, limit int) []string { trimmed := strings.TrimSpace(text) if trimmed == "" { diff --git a/internal/channel/processor.go b/internal/channel/processor.go index 9d6be6e7..6f4e79f2 100644 --- a/internal/channel/processor.go +++ b/internal/channel/processor.go @@ -2,7 +2,7 @@ package channel import "context" -// InboundProcessor 负责处理入站消息并通过 sender 回传响应。 +// InboundProcessor handles inbound messages and replies through the given sender. type InboundProcessor interface { HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error } diff --git a/internal/channel/registry.go b/internal/channel/registry.go index 6826092e..f70ed507 100644 --- a/internal/channel/registry.go +++ b/internal/channel/registry.go @@ -6,6 +6,7 @@ import ( "sync" ) +// ChannelDescriptor holds all metadata and hooks for a registered channel type. type ChannelDescriptor struct { Type ChannelType DisplayName string @@ -32,6 +33,7 @@ var registry = &channelRegistry{ items: map[ChannelType]ChannelDescriptor{}, } +// RegisterChannel adds a channel descriptor to the global registry. func RegisterChannel(desc ChannelDescriptor) error { normalized := normalizeChannelType(string(desc.Type)) if normalized == "" { @@ -50,12 +52,14 @@ func RegisterChannel(desc ChannelDescriptor) error { return nil } +// MustRegisterChannel calls RegisterChannel and panics on error. func MustRegisterChannel(desc ChannelDescriptor) { if err := RegisterChannel(desc); err != nil { panic(err) } } +// UnregisterChannel removes a channel type from the global registry. func UnregisterChannel(channelType ChannelType) bool { normalized := normalizeChannelType(channelType.String()) if normalized == "" { @@ -70,6 +74,7 @@ func UnregisterChannel(channelType ChannelType) bool { return true } +// GetChannelDescriptor returns the descriptor for the given channel type. func GetChannelDescriptor(channelType ChannelType) (ChannelDescriptor, bool) { normalized := normalizeChannelType(channelType.String()) registry.mu.RLock() @@ -78,6 +83,7 @@ func GetChannelDescriptor(channelType ChannelType) (ChannelDescriptor, bool) { return desc, ok } +// ListChannelDescriptors returns all registered channel descriptors. func ListChannelDescriptors() []ChannelDescriptor { registry.mu.RLock() defer registry.mu.RUnlock() @@ -88,6 +94,7 @@ func ListChannelDescriptors() []ChannelDescriptor { return items } +// GetChannelCapabilities returns the capability matrix for the given channel type. func GetChannelCapabilities(channelType ChannelType) (ChannelCapabilities, bool) { desc, ok := GetChannelDescriptor(channelType) if !ok { @@ -96,6 +103,7 @@ func GetChannelCapabilities(channelType ChannelType) (ChannelCapabilities, bool) return desc.Capabilities, true } +// GetChannelOutboundPolicy returns the outbound policy for the given channel type. func GetChannelOutboundPolicy(channelType ChannelType) (OutboundPolicy, bool) { desc, ok := GetChannelDescriptor(channelType) if !ok { @@ -104,6 +112,7 @@ func GetChannelOutboundPolicy(channelType ChannelType) (OutboundPolicy, bool) { return desc.OutboundPolicy, true } +// GetChannelConfigSchema returns the configuration schema for the given channel type. func GetChannelConfigSchema(channelType ChannelType) (ConfigSchema, bool) { desc, ok := GetChannelDescriptor(channelType) if !ok { @@ -112,6 +121,7 @@ func GetChannelConfigSchema(channelType ChannelType) (ConfigSchema, bool) { return desc.ConfigSchema, true } +// GetChannelUserConfigSchema returns the user-binding configuration schema for the given channel type. func GetChannelUserConfigSchema(channelType ChannelType) (ConfigSchema, bool) { desc, ok := GetChannelDescriptor(channelType) if !ok { @@ -120,6 +130,7 @@ func GetChannelUserConfigSchema(channelType ChannelType) (ConfigSchema, bool) { return desc.UserConfigSchema, true } +// IsConfigless reports whether the channel type operates without per-bot configuration. func IsConfigless(channelType ChannelType) bool { desc, ok := GetChannelDescriptor(channelType) if !ok { diff --git a/internal/channel/schema.go b/internal/channel/schema.go index cb9e7db7..2f818270 100644 --- a/internal/channel/schema.go +++ b/internal/channel/schema.go @@ -1,5 +1,6 @@ package channel +// FieldType enumerates the supported configuration field types. type FieldType string const ( @@ -10,7 +11,7 @@ const ( FieldEnum FieldType = "enum" ) -// FieldSchema 定义单个配置字段的结构化描述。 +// FieldSchema describes a single configuration field. type FieldSchema struct { Type FieldType `json:"type"` Required bool `json:"required"` @@ -20,7 +21,7 @@ type FieldSchema struct { Example any `json:"example,omitempty"` } -// ConfigSchema 描述通道配置或用户绑定的结构。 +// ConfigSchema describes the structure of a channel or user-binding configuration. type ConfigSchema struct { Version int `json:"version"` Fields map[string]FieldSchema `json:"fields"` diff --git a/internal/channel/service.go b/internal/channel/service.go index f3cdd25e..5a79ae90 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -15,14 +15,17 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Service provides CRUD operations for channel configurations, user bindings, and sessions. type Service struct { queries *sqlc.Queries } +// NewService creates a Service backed by the given database queries. func NewService(queries *sqlc.Queries) *Service { return &Service{queries: queries} } +// UpsertConfig creates or updates a bot's channel configuration. func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { if s.queries == nil { return ChannelConfig{}, fmt.Errorf("channel queries not configured") @@ -58,14 +61,6 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch if err != nil { return ChannelConfig{}, err } - capabilities := req.Capabilities - if capabilities == nil { - capabilities = map[string]any{} - } - capabilitiesPayload, err := json.Marshal(capabilities) - if err != nil { - return ChannelConfig{}, err - } status := strings.TrimSpace(req.Status) if status == "" { status = "pending" @@ -85,7 +80,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch }, SelfIdentity: selfPayload, Routing: routingPayload, - Capabilities: capabilitiesPayload, + Capabilities: []byte("{}"), Status: status, VerifiedAt: verifiedAt, }) @@ -95,6 +90,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch return normalizeChannelConfig(row) } +// UpsertUserConfig creates or updates a user's channel binding. func (s *Service) UpsertUserConfig(ctx context.Context, actorUserID string, channelType ChannelType, req UpsertUserConfigRequest) (ChannelUserBinding, error) { if s.queries == nil { return ChannelUserBinding{}, fmt.Errorf("channel queries not configured") @@ -125,6 +121,8 @@ func (s *Service) UpsertUserConfig(ctx context.Context, actorUserID string, chan return normalizeChannelUserBindingRow(row) } +// ResolveEffectiveConfig returns the active channel configuration for a bot. +// For configless channel types, a synthetic config is returned. func (s *Service) ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { if s.queries == nil { return ChannelConfig{}, fmt.Errorf("channel queries not configured") @@ -156,6 +154,7 @@ func (s *Service) ResolveEffectiveConfig(ctx context.Context, botID string, chan return ChannelConfig{}, fmt.Errorf("channel config not found") } +// ListConfigsByType returns all channel configurations of the given type. func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { if s.queries == nil { return nil, fmt.Errorf("channel queries not configured") @@ -178,6 +177,7 @@ func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType return items, nil } +// GetUserConfig returns the user's channel binding for the given channel type. func (s *Service) GetUserConfig(ctx context.Context, actorUserID string, channelType ChannelType) (ChannelUserBinding, error) { if s.queries == nil { return ChannelUserBinding{}, fmt.Errorf("channel queries not configured") @@ -213,6 +213,7 @@ func (s *Service) GetUserConfig(ctx context.Context, actorUserID string, channel }, nil } +// ListUserConfigsByType returns all user bindings for the given channel type. func (s *Service) ListUserConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelUserBinding, error) { if s.queries == nil { return nil, fmt.Errorf("channel queries not configured") @@ -223,7 +224,7 @@ func (s *Service) ListUserConfigsByType(ctx context.Context, channelType Channel } items := make([]ChannelUserBinding, 0, len(rows)) for _, row := range rows { - item, err := normalizeChannelUserBindingListRow(row) + item, err := normalizeChannelUserBindingRow(row) if err != nil { return nil, err } @@ -232,6 +233,7 @@ func (s *Service) ListUserConfigsByType(ctx context.Context, channelType Channel return items, nil } +// GetChannelSession returns the session with the given ID, or an empty session if not found. func (s *Service) GetChannelSession(ctx context.Context, sessionID string) (ChannelSession, error) { if s.queries == nil { return ChannelSession{}, fmt.Errorf("channel queries not configured") @@ -246,6 +248,7 @@ func (s *Service) GetChannelSession(ctx context.Context, sessionID string) (Chan return normalizeChannelSession(row) } +// ListSessionsByBotPlatform returns all sessions for the given bot and platform. func (s *Service) ListSessionsByBotPlatform(ctx context.Context, botID, platform string) ([]ChannelSession, error) { if s.queries == nil { return nil, fmt.Errorf("channel queries not configured") @@ -280,6 +283,7 @@ func (s *Service) ListSessionsByBotPlatform(ctx context.Context, botID, platform return items, nil } +// UpsertChannelSession creates or updates a channel session record. func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, botID string, channelConfigID string, userID string, contactID string, platform string, replyTarget string, threadID string, metadata map[string]any) error { if s.queries == nil { return fmt.Errorf("channel queries not configured") @@ -339,6 +343,7 @@ func (s *Service) UpsertChannelSession(ctx context.Context, sessionID string, bo return err } +// ResolveUserBinding finds the user ID whose channel binding matches the given criteria. func (s *Service) ResolveUserBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { rows, err := s.ListUserConfigsByType(ctx, channelType) if err != nil { @@ -368,10 +373,6 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { if err != nil { return ChannelConfig{}, err } - capabilities, err := DecodeConfigMap(row.Capabilities) - if err != nil { - return ChannelConfig{}, err - } verifiedAt := time.Time{} if row.VerifiedAt.Valid { verifiedAt = row.VerifiedAt.Time @@ -387,9 +388,8 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { Credentials: credentials, ExternalIdentity: externalIdentity, SelfIdentity: selfIdentity, - Routing: routing, - Capabilities: capabilities, - Status: strings.TrimSpace(row.Status), + Routing: routing, + Status: strings.TrimSpace(row.Status), VerifiedAt: verifiedAt, CreatedAt: timeFromPg(row.CreatedAt), UpdatedAt: timeFromPg(row.UpdatedAt), @@ -411,21 +411,6 @@ func normalizeChannelUserBindingRow(row sqlc.UserChannelBinding) (ChannelUserBin }, nil } -func normalizeChannelUserBindingListRow(row sqlc.UserChannelBinding) (ChannelUserBinding, error) { - config, err := DecodeConfigMap(row.Config) - if err != nil { - return ChannelUserBinding{}, err - } - return ChannelUserBinding{ - ID: toUUIDString(row.ID), - ChannelType: ChannelType(row.ChannelType), - UserID: toUUIDString(row.UserID), - Config: config, - CreatedAt: timeFromPg(row.CreatedAt), - UpdatedAt: timeFromPg(row.UpdatedAt), - }, nil -} - func normalizeChannelSession(row sqlc.ChannelSession) (ChannelSession, error) { metadata, err := DecodeConfigMap(row.Metadata) if err != nil { @@ -475,6 +460,7 @@ func timeFromPg(value pgtype.Timestamptz) time.Time { return time.Time{} } +// String returns the channel type as a plain string. func (c ChannelType) String() string { return string(c) } diff --git a/internal/channel/target.go b/internal/channel/target.go index 67a88548..28bf465b 100644 --- a/internal/channel/target.go +++ b/internal/channel/target.go @@ -2,16 +2,20 @@ package channel import "strings" +// TargetHint provides a display label and example for a target format. type TargetHint struct { Example string `json:"example,omitempty"` Label string `json:"label,omitempty"` } +// TargetSpec describes the expected format of a delivery target for a channel type. type TargetSpec struct { Format string `json:"format"` Hints []TargetHint `json:"hints,omitempty"` } +// NormalizeTarget applies the channel-specific target normalization function. +// It returns the normalized string and true if a normalizer was found, otherwise the trimmed input and false. func NormalizeTarget(channelType ChannelType, raw string) (string, bool) { desc, ok := GetChannelDescriptor(channelType) if !ok || desc.NormalizeTarget == nil { diff --git a/internal/channel/types.go b/internal/channel/types.go index 3b66934c..79ab3b53 100644 --- a/internal/channel/types.go +++ b/internal/channel/types.go @@ -1,3 +1,5 @@ +// Package channel provides a unified abstraction for multi-platform messaging channels. +// It defines types, interfaces, and a registry for channel adapters such as Telegram and Feishu. package channel import ( @@ -6,8 +8,10 @@ import ( "time" ) +// ChannelType identifies a messaging platform (e.g., "telegram", "feishu"). type ChannelType string +// ParseChannelType validates and normalizes a raw string into a registered ChannelType. func ParseChannelType(raw string) (ChannelType, error) { normalized := normalizeChannelType(raw) if normalized == "" { @@ -19,12 +23,14 @@ func ParseChannelType(raw string) (ChannelType, error) { return normalized, nil } +// Identity represents a sender's identity on a channel. type Identity struct { ExternalID string DisplayName string Attributes map[string]string } +// Attribute returns the trimmed value for the given key, or empty string if absent. func (i Identity) Attribute(key string) string { if i.Attributes == nil { return "" @@ -32,6 +38,7 @@ func (i Identity) Attribute(key string) string { return strings.TrimSpace(i.Attributes[key]) } +// Conversation holds metadata about the chat or group context. type Conversation struct { ID string Type string @@ -40,6 +47,7 @@ type Conversation struct { Metadata map[string]any } +// InboundMessage is a message received from an external channel. type InboundMessage struct { Channel ChannelType Message Message @@ -53,7 +61,8 @@ type InboundMessage struct { Metadata map[string]any } -// SessionID 结构: platform:bot_id:conversation_id[:sender_id] +// SessionID returns a stable identifier for the conversation session. +// Format: platform:bot_id:conversation_id[:sender_id]. func (m InboundMessage) SessionID() string { if strings.TrimSpace(m.SessionKey) != "" { return strings.TrimSpace(m.SessionKey) @@ -65,10 +74,10 @@ func (m InboundMessage) SessionID() string { return GenerateSessionID(string(m.Channel), m.BotID, m.Conversation.ID, m.Conversation.Type, senderID) } -// GenerateSessionID 统一生成 SessionID 的逻辑 +// GenerateSessionID builds a session identifier from platform, bot, conversation, and sender info. +// For group chats, the sender ID is appended to provide per-user context. func GenerateSessionID(platform, botID, conversationID, conversationType, senderID string) string { parts := []string{platform, botID, conversationID} - // 如果是群聊,增加发送者 ID 以支持个人上下文 ct := strings.ToLower(strings.TrimSpace(conversationType)) if ct != "" && ct != "p2p" && ct != "private" { senderID = strings.TrimSpace(senderID) @@ -79,11 +88,13 @@ func GenerateSessionID(platform, botID, conversationID, conversationType, sender return strings.Join(parts, ":") } +// OutboundMessage pairs a delivery target with the message content. type OutboundMessage struct { Target string `json:"target"` Message Message `json:"message"` } +// MessageFormat indicates how the message text should be rendered. type MessageFormat string const ( @@ -92,6 +103,7 @@ const ( MessageFormatRich MessageFormat = "rich" ) +// MessagePartType identifies the kind of a rich-text message part. type MessagePartType string const ( @@ -102,6 +114,7 @@ const ( MessagePartEmoji MessagePartType = "emoji" ) +// MessageTextStyle describes inline formatting for a text part. type MessageTextStyle string const ( @@ -111,6 +124,7 @@ const ( MessageStyleCode MessageTextStyle = "code" ) +// MessagePart is a single element within a rich-text message. type MessagePart struct { Type MessagePartType `json:"type"` Text string `json:"text,omitempty"` @@ -122,6 +136,7 @@ type MessagePart struct { Metadata map[string]any `json:"metadata,omitempty"` } +// AttachmentType classifies the kind of binary attachment. type AttachmentType string const ( @@ -133,6 +148,7 @@ const ( AttachmentGIF AttachmentType = "gif" ) +// Attachment represents a binary file attached to a message. type Attachment struct { Type AttachmentType `json:"type"` URL string `json:"url,omitempty"` @@ -147,6 +163,7 @@ type Attachment struct { Metadata map[string]any `json:"metadata,omitempty"` } +// Action describes an interactive button or link in a message. type Action struct { Type string `json:"type"` Label string `json:"label,omitempty"` @@ -154,15 +171,18 @@ type Action struct { URL string `json:"url,omitempty"` } +// ThreadRef references a conversation thread by ID. type ThreadRef struct { ID string `json:"id"` } +// ReplyRef points to a message being replied to. type ReplyRef struct { Target string `json:"target,omitempty"` MessageID string `json:"message_id,omitempty"` } +// Message is the unified message structure used across all channels. type Message struct { ID string `json:"id,omitempty"` Format MessageFormat `json:"format,omitempty"` @@ -175,6 +195,7 @@ type Message struct { Metadata map[string]any `json:"metadata,omitempty"` } +// IsEmpty reports whether the message carries no content. func (m Message) IsEmpty() bool { return strings.TrimSpace(m.Text) == "" && len(m.Parts) == 0 && @@ -182,6 +203,7 @@ func (m Message) IsEmpty() bool { len(m.Actions) == 0 } +// PlainText extracts the plain text representation of the message. func (m Message) PlainText() string { if strings.TrimSpace(m.Text) != "" { return strings.TrimSpace(m.Text) @@ -211,11 +233,13 @@ func (m Message) PlainText() string { return strings.Join(lines, "\n") } +// BindingCriteria specifies conditions for matching a user-channel binding. type BindingCriteria struct { ExternalID string Attributes map[string]string } +// Attribute returns the trimmed value for the given key, or empty string if absent. func (c BindingCriteria) Attribute(key string) string { if c.Attributes == nil { return "" @@ -223,6 +247,7 @@ func (c BindingCriteria) Attribute(key string) string { return strings.TrimSpace(c.Attributes[key]) } +// BindingCriteriaFromIdentity creates BindingCriteria from a channel Identity. func BindingCriteriaFromIdentity(identity Identity) BindingCriteria { return BindingCriteria{ ExternalID: strings.TrimSpace(identity.ExternalID), @@ -230,6 +255,7 @@ func BindingCriteriaFromIdentity(identity Identity) BindingCriteria { } } +// ChannelConfig holds the configuration for a bot's channel integration. type ChannelConfig struct { ID string BotID string @@ -238,13 +264,13 @@ type ChannelConfig struct { ExternalIdentity string SelfIdentity map[string]any Routing map[string]any - Capabilities map[string]any Status string VerifiedAt time.Time CreatedAt time.Time UpdatedAt time.Time } +// ChannelUserBinding represents a user's binding to a specific channel type. type ChannelUserBinding struct { ID string ChannelType ChannelType @@ -254,20 +280,22 @@ type ChannelUserBinding struct { UpdatedAt time.Time } +// UpsertConfigRequest is the input for creating or updating a channel configuration. type UpsertConfigRequest struct { Credentials map[string]any `json:"credentials"` ExternalIdentity string `json:"external_identity,omitempty"` SelfIdentity map[string]any `json:"self_identity,omitempty"` Routing map[string]any `json:"routing,omitempty"` - Capabilities map[string]any `json:"capabilities,omitempty"` Status string `json:"status,omitempty"` VerifiedAt *time.Time `json:"verified_at,omitempty"` } +// UpsertUserConfigRequest is the input for creating or updating a user-channel binding. type UpsertUserConfigRequest struct { Config map[string]any `json:"config"` } +// ChannelSession tracks an active conversation session on a channel. type ChannelSession struct { SessionID string BotID string @@ -282,6 +310,7 @@ type ChannelSession struct { UpdatedAt time.Time } +// SendRequest is the input for sending an outbound message through a channel. type SendRequest struct { Target string `json:"target,omitempty"` UserID string `json:"user_id,omitempty"` diff --git a/internal/chat/assistant_output.go b/internal/chat/assistant_output.go index a36f15cb..ae00f91f 100644 --- a/internal/chat/assistant_output.go +++ b/internal/chat/assistant_output.go @@ -2,51 +2,35 @@ package chat import "strings" -type AssistantOutput struct { - Content string - Parts []ContentPart -} - -func ExtractAssistantOutputs(messages []GatewayMessage) []AssistantOutput { +// ExtractAssistantOutputs collects assistant-role outputs from a slice of ModelMessages. +func ExtractAssistantOutputs(messages []ModelMessage) []AssistantOutput { if len(messages) == 0 { return nil } outputs := make([]AssistantOutput, 0, len(messages)) for _, msg := range messages { - normalized := normalizeGatewayMessage(msg) - for _, item := range normalized { - if item.Role != "assistant" { - continue - } - content := strings.TrimSpace(item.Content) - parts := make([]ContentPart, 0, len(item.Parts)) - for _, part := range item.Parts { - if !hasContentPartValue(part) { - continue - } - parts = append(parts, part) - } - if content == "" && len(parts) == 0 { - continue - } - outputs = append(outputs, AssistantOutput{ - Content: content, - Parts: parts, - }) + if msg.Role != "assistant" { + continue } + content := strings.TrimSpace(msg.TextContent()) + parts := filterContentParts(msg.ContentParts()) + if content == "" && len(parts) == 0 { + continue + } + outputs = append(outputs, AssistantOutput{Content: content, Parts: parts}) } return outputs } -func hasContentPartValue(part ContentPart) bool { - if strings.TrimSpace(part.Text) != "" { - return true +func filterContentParts(parts []ContentPart) []ContentPart { + if len(parts) == 0 { + return nil } - if strings.TrimSpace(part.URL) != "" { - return true + filtered := make([]ContentPart, 0, len(parts)) + for _, p := range parts { + if p.HasValue() { + filtered = append(filtered, p) + } } - if strings.TrimSpace(part.Emoji) != "" { - return true - } - return false + return filtered } diff --git a/internal/chat/chat.go b/internal/chat/chat.go deleted file mode 100644 index 91c1a427..00000000 --- a/internal/chat/chat.go +++ /dev/null @@ -1 +0,0 @@ -package chat \ No newline at end of file diff --git a/internal/chat/normalize.go b/internal/chat/normalize.go deleted file mode 100644 index 8f3962d9..00000000 --- a/internal/chat/normalize.go +++ /dev/null @@ -1,407 +0,0 @@ -package chat - -import ( - "encoding/json" - "strings" -) - -type toolResult struct { - ToolCallID string - Content string -} - -func normalizeGatewayMessages(messages []GatewayMessage) []GatewayMessage { - normalized := make([]GatewayMessage, 0, len(messages)) - for _, msg := range messages { - items := normalizeGatewayMessage(msg) - normalized = append(normalized, toGatewayMessages(items)...) - } - return normalized -} - -func normalizeGatewayMessage(msg GatewayMessage) []NormalizedMessage { - if msg == nil { - return nil - } - role := getString(msg["role"]) - if role == "" { - role = "assistant" - } - - var toolCalls []ToolCall - var textParts []ContentPart - var toolResults []toolResult - - if rawCalls, ok := msg["tool_calls"].([]any); ok { - for _, raw := range rawCalls { - if call := normalizeToolCall(raw); call.Function.Name != "" { - toolCalls = append(toolCalls, call) - } - } - } - - switch content := msg["content"].(type) { - case string: - if strings.TrimSpace(content) != "" || len(toolCalls) > 0 { - normalized := NormalizedMessage{Role: role} - if strings.TrimSpace(content) != "" { - normalized.Content = content - } - if len(toolCalls) > 0 { - normalized.ToolCalls = toolCalls - } - return appendToolResults([]NormalizedMessage{normalized}, toolResults) - } - case []any: - for _, part := range content { - switch p := part.(type) { - case string: - if strings.TrimSpace(p) != "" { - textParts = append(textParts, ContentPart{Type: "text", Text: p}) - } - case map[string]any: - if contentPart, ok := normalizeContentPart(p); ok { - textParts = append(textParts, contentPart) - continue - } - if call := normalizeToolCall(p); call.Function.Name != "" { - toolCalls = append(toolCalls, call) - continue - } - if result := normalizeToolResult(p); result.ToolCallID != "" { - toolResults = append(toolResults, result) - continue - } - if encoded := toJSONString(p); encoded != "" { - textParts = append(textParts, ContentPart{Type: "text", Text: encoded}) - } - default: - if encoded := toJSONString(p); encoded != "" { - textParts = append(textParts, ContentPart{Type: "text", Text: encoded}) - } - } - } - case map[string]any: - if contentPart, ok := normalizeContentPart(content); ok { - textParts = append(textParts, contentPart) - } else if encoded := toJSONString(content); encoded != "" { - textParts = append(textParts, ContentPart{Type: "text", Text: encoded}) - } - } - - if len(textParts) == 0 && len(toolCalls) == 0 && len(toolResults) == 0 { - return nil - } - - output := NormalizedMessage{Role: role} - if len(toolCalls) > 0 { - output.ToolCalls = toolCalls - } - if len(textParts) == 1 && len(toolCalls) == 0 { - output.Content = textParts[0].Text - } else if len(textParts) > 0 { - output.Parts = textParts - } - - return appendToolResults([]NormalizedMessage{output}, toolResults) -} - -func appendToolResults(messages []NormalizedMessage, results []toolResult) []NormalizedMessage { - if len(results) == 0 { - return messages - } - for _, result := range results { - if strings.TrimSpace(result.ToolCallID) == "" { - continue - } - item := NormalizedMessage{ - Role: "tool", - ToolCallID: result.ToolCallID, - } - if strings.TrimSpace(result.Content) != "" { - item.Content = result.Content - } - messages = append(messages, item) - } - return messages -} - -func normalizeTextPart(part map[string]any) string { - if part == nil { - return "" - } - if partType, _ := part["type"].(string); partType == "text" { - if text, ok := part["text"].(string); ok { - return text - } - } - if text, ok := part["text"].(string); ok && strings.TrimSpace(text) != "" { - return text - } - return "" -} - -func normalizeContentPart(part map[string]any) (ContentPart, bool) { - if part == nil { - return ContentPart{}, false - } - partType := getString(part["type"]) - if partType == "" { - partType = "text" - } - if partType == "tool_use" || partType == "tool-call" || partType == "function_call" || partType == "tool_result" || partType == "tool-result" { - return ContentPart{}, false - } - text := normalizeTextPart(part) - url := getString(part["url"]) - emoji := getString(part["emoji"]) - if strings.TrimSpace(text) == "" && strings.TrimSpace(url) == "" && strings.TrimSpace(emoji) == "" { - return ContentPart{}, false - } - styles := normalizeStringSlice(part["styles"]) - metadata := map[string]any{} - if raw, ok := part["metadata"].(map[string]any); ok && raw != nil { - metadata = raw - } - return ContentPart{ - Type: partType, - Text: text, - URL: url, - Styles: styles, - Language: getString(part["language"]), - UserID: getString(part["user_id"]), - Emoji: emoji, - Metadata: metadata, - }, true -} - -func normalizeStringSlice(raw any) []string { - switch value := raw.(type) { - case []string: - return value - case []any: - items := make([]string, 0, len(value)) - for _, entry := range value { - if str, ok := entry.(string); ok && strings.TrimSpace(str) != "" { - items = append(items, strings.TrimSpace(str)) - } - } - return items - default: - return nil - } -} - -func normalizeToolCall(part any) ToolCall { - switch value := part.(type) { - case map[string]any: - if valueType, _ := value["type"].(string); valueType == "tool_use" || valueType == "tool-call" || valueType == "function_call" { - return ToolCall{ - ID: getString(value["id"]), - Type: "function", - Function: ToolCallFunction{ - Name: getString(value["name"]), - Arguments: toJSONString(value["input"], value["args"], value["arguments"]), - }, - } - } - if fc, ok := value["function_call"].(map[string]any); ok { - return ToolCall{ - ID: getString(value["id"]), - Type: "function", - Function: ToolCallFunction{ - Name: getString(fc["name"]), - Arguments: toJSONString(fc["arguments"], fc["args"]), - }, - } - } - if fc, ok := value["functionCall"].(map[string]any); ok { - return ToolCall{ - ID: getString(value["id"]), - Type: "function", - Function: ToolCallFunction{ - Name: getString(fc["name"]), - Arguments: toJSONString(fc["args"], fc["arguments"]), - }, - } - } - if fn, ok := value["function"].(map[string]any); ok { - return ToolCall{ - ID: getString(value["id"]), - Type: "function", - Function: ToolCallFunction{ - Name: getString(fn["name"]), - Arguments: toJSONString(fn["arguments"]), - }, - } - } - } - return ToolCall{} -} - -func normalizeToolResult(part map[string]any) toolResult { - if part == nil { - return toolResult{} - } - if partType, _ := part["type"].(string); partType == "tool_result" || partType == "tool-result" { - return toolResult{ - ToolCallID: firstString(part["tool_use_id"], part["toolCallId"], part["tool_call_id"], part["id"]), - Content: normalizeToolResultContent(part["content"], part["result"], part["output"]), - } - } - if raw, ok := part["toolResult"].(map[string]any); ok { - return toolResult{ - ToolCallID: firstString(raw["toolUseId"], raw["tool_call_id"], raw["id"]), - Content: normalizeToolResultContent(raw["content"], raw["output"], raw["result"]), - } - } - if raw, ok := part["functionResponse"].(map[string]any); ok { - return toolResult{ - ToolCallID: firstString(raw["id"]), - Content: normalizeToolResultContent(raw["response"], raw["output"], raw["result"]), - } - } - return toolResult{} -} - -func normalizeToolResultContent(values ...any) string { - for _, value := range values { - if value == nil { - continue - } - switch v := value.(type) { - case string: - if strings.TrimSpace(v) != "" { - return v - } - case []any: - parts := make([]string, 0, len(v)) - for _, item := range v { - switch itemValue := item.(type) { - case string: - if strings.TrimSpace(itemValue) != "" { - parts = append(parts, itemValue) - } - case map[string]any: - if text := normalizeTextPart(itemValue); text != "" { - parts = append(parts, text) - } else if encoded := toJSONString(itemValue); encoded != "" { - parts = append(parts, encoded) - } - default: - if encoded := toJSONString(itemValue); encoded != "" { - parts = append(parts, encoded) - } - } - } - if len(parts) > 0 { - return strings.Join(parts, "\n") - } - case map[string]any: - if text := normalizeTextPart(v); text != "" { - return text - } - if encoded := toJSONString(v); encoded != "" { - return encoded - } - default: - if encoded := toJSONString(v); encoded != "" { - return encoded - } - } - } - return "" -} - -func toGatewayMessages(messages []NormalizedMessage) []GatewayMessage { - converted := make([]GatewayMessage, 0, len(messages)) - for _, msg := range messages { - item := GatewayMessage{ - "role": msg.Role, - } - if strings.TrimSpace(msg.Content) != "" { - item["content"] = msg.Content - } else if len(msg.Parts) > 0 { - parts := make([]map[string]any, 0, len(msg.Parts)) - for _, part := range msg.Parts { - entry := map[string]any{ - "type": part.Type, - } - if strings.TrimSpace(part.Text) != "" { - entry["text"] = part.Text - } - parts = append(parts, entry) - } - item["content"] = parts - } - if len(msg.ToolCalls) > 0 { - payload := make([]map[string]any, 0, len(msg.ToolCalls)) - for _, call := range msg.ToolCalls { - if strings.TrimSpace(call.Function.Name) == "" { - continue - } - entry := map[string]any{ - "type": "function", - "function": map[string]any{ - "name": call.Function.Name, - "arguments": call.Function.Arguments, - }, - } - if strings.TrimSpace(call.ID) != "" { - entry["id"] = call.ID - } - payload = append(payload, entry) - } - if len(payload) > 0 { - item["tool_calls"] = payload - } - } - if strings.TrimSpace(msg.ToolCallID) != "" { - item["tool_call_id"] = msg.ToolCallID - } - if strings.TrimSpace(msg.Name) != "" { - item["name"] = msg.Name - } - converted = append(converted, item) - } - return converted -} - -func getString(value any) string { - if raw, ok := value.(string); ok { - return raw - } - return "" -} - -func firstString(values ...any) string { - for _, value := range values { - if raw, ok := value.(string); ok && strings.TrimSpace(raw) != "" { - return raw - } - } - return "" -} - -func toJSONString(values ...any) string { - for _, value := range values { - if value == nil { - continue - } - if raw, ok := value.(string); ok { - if strings.TrimSpace(raw) != "" { - return raw - } - continue - } - encoded, err := json.Marshal(value) - if err != nil { - continue - } - if strings.TrimSpace(string(encoded)) == "" { - continue - } - return string(encoded) - } - return "" -} diff --git a/internal/chat/resolver.go b/internal/chat/resolver.go index 35c01286..b6bf8042 100644 --- a/internal/chat/resolver.go +++ b/internal/chat/resolver.go @@ -18,6 +18,7 @@ import ( "github.com/memohai/memoh/internal/history" "github.com/memohai/memoh/internal/memory" "github.com/memohai/memoh/internal/models" + "github.com/memohai/memoh/internal/schedule" "github.com/memohai/memoh/internal/settings" ) @@ -37,15 +38,17 @@ type Resolver struct { streamingClient *http.Client } -type userSettings struct { - ChatModelID string - MemoryModelID string - EmbeddingModelID string - MaxContextLoadTime int - Language string -} - -func NewResolver(log *slog.Logger, modelsService *models.Service, queries *sqlc.Queries, memoryService *memory.Service, historyService *history.Service, settingsService *settings.Service, gatewayBaseURL string, timeout time.Duration) *Resolver { +// NewResolver creates a Resolver that communicates with the agent gateway. +func NewResolver( + log *slog.Logger, + modelsService *models.Service, + queries *sqlc.Queries, + memoryService *memory.Service, + historyService *history.Service, + settingsService *settings.Service, + gatewayBaseURL string, + timeout time.Duration, +) *Resolver { if strings.TrimSpace(gatewayBaseURL) == "" { gatewayBaseURL = "http://127.0.0.1:8081" } @@ -62,14 +65,12 @@ func NewResolver(log *slog.Logger, modelsService *models.Service, queries *sqlc. gatewayBaseURL: gatewayBaseURL, timeout: timeout, logger: log.With(slog.String("service", "chat")), - httpClient: &http.Client{ - Timeout: timeout, - }, + httpClient: &http.Client{Timeout: timeout}, streamingClient: &http.Client{}, } } -// ---------- gateway payload types ---------- +// --- gateway payload --- type gatewayModelConfig struct { ModelID string `json:"modelId"` @@ -92,83 +93,82 @@ type gatewayIdentity struct { SessionToken string `json:"sessionToken,omitempty"` } -type agentGatewayRequest struct { +type gatewayRequest struct { Model gatewayModelConfig `json:"model"` ActiveContextTime int `json:"activeContextTime"` Channels []string `json:"channels"` CurrentChannel string `json:"currentChannel"` AllowedActions []string `json:"allowedActions,omitempty"` - Messages []GatewayMessage `json:"messages"` + Messages []ModelMessage `json:"messages"` Skills []string `json:"skills"` Query string `json:"query"` Identity gatewayIdentity `json:"identity"` + Attachments []any `json:"attachments"` } -type agentGatewayResponse struct { - Messages []GatewayMessage `json:"messages"` - Skills []string `json:"skills"` +type gatewayResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` } -// ---------- Chat ---------- +// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) --- -func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { +type resolvedContext struct { + payload gatewayRequest + model models.GetResponse + provider sqlc.LlmProvider +} + +func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContext, error) { if strings.TrimSpace(req.Query) == "" { - return ChatResponse{}, fmt.Errorf("query is required") + return resolvedContext{}, fmt.Errorf("query is required") } if strings.TrimSpace(req.BotID) == "" { - return ChatResponse{}, fmt.Errorf("bot id is required") + return resolvedContext{}, fmt.Errorf("bot id is required") } if strings.TrimSpace(req.SessionID) == "" { - return ChatResponse{}, fmt.Errorf("session id is required") + return resolvedContext{}, fmt.Errorf("session id is required") } + skipHistory := req.MaxContextLoadTime < 0 - settings, err := r.loadUserSettings(ctx, req.UserID) + userSettings, err := r.loadUserSettings(ctx, req.UserID) if err != nil { - return ChatResponse{}, err + return resolvedContext{}, err } - chatModel, provider, err := r.selectChatModel(ctx, req, settings) + chatModel, provider, err := r.selectChatModel(ctx, req, userSettings) if err != nil { - return ChatResponse{}, err + return resolvedContext{}, err } clientType, err := normalizeClientType(provider.ClientType) if err != nil { - return ChatResponse{}, err + return resolvedContext{}, err } - maxContextLoadTime, language, err := r.loadBotSettings(ctx, req.BotID) + botSettings, err := r.loadBotSettings(ctx, req.BotID) if err != nil { - return ChatResponse{}, err - } - if req.MaxContextLoadTime > 0 { - maxContextLoadTime = req.MaxContextLoadTime - } - if strings.TrimSpace(req.Language) != "" { - language = req.Language + return resolvedContext{}, err } + maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes) - var messages []GatewayMessage + var messages []ModelMessage var historySkills []string if !skipHistory { - messages, err = r.loadHistoryMessages(ctx, req.BotID, req.SessionID, maxContextLoadTime) + messages, err = r.loadHistoryMessages(ctx, req.BotID, req.SessionID, maxCtx) if err != nil { - return ChatResponse{}, err + return resolvedContext{}, err } - historySkills, err = r.loadHistorySkills(ctx, req.BotID, req.SessionID, maxContextLoadTime) + historySkills, err = r.loadHistorySkills(ctx, req.BotID, req.SessionID, maxCtx) if err != nil { - return ChatResponse{}, err + return resolvedContext{}, err } } - if len(req.Messages) > 0 { - messages = append(messages, req.Messages...) - } - messages = sanitizeGatewayMessages(messages) - messages = normalizeGatewayMessagesForModel(messages) - skills := normalizeSkills(append(historySkills, req.Skills...)) - + messages = append(messages, req.Messages...) + messages = sanitizeMessages(messages) + skills := dedup(append(historySkills, req.Skills...)) containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) - payload := agentGatewayRequest{ + payload := gatewayRequest{ Model: gatewayModelConfig{ ModelID: chatModel.ModelID, ClientType: clientType, @@ -176,318 +176,179 @@ func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, err APIKey: provider.ApiKey, BaseURL: provider.BaseUrl, }, - ActiveContextTime: normalizeMaxContextLoad(maxContextLoadTime), - Channels: req.Channels, + ActiveContextTime: maxCtx, + Channels: nonNilStrings(req.Channels), CurrentChannel: req.CurrentChannel, AllowedActions: req.AllowedActions, - Messages: messages, - Skills: skills, + Messages: nonNilMessages(messages), + Skills: nonNilStrings(skills), Query: req.Query, Identity: gatewayIdentity{ BotID: req.BotID, SessionID: req.SessionID, ContainerID: containerID, - ContactID: defaultString(req.ContactID, req.UserID, req.BotID), - ContactName: defaultString(req.ContactName, "User"), + ContactID: firstNonEmpty(req.ContactID, req.UserID, req.BotID), + ContactName: firstNonEmpty(req.ContactName, "User"), ContactAlias: req.ContactAlias, UserID: req.UserID, CurrentPlatform: req.CurrentChannel, ReplyTarget: req.ReplyTarget, SessionToken: req.SessionToken, }, + Attachments: []any{}, } - _ = language // language is embedded in system prompt by the gateway - resp, err := r.postChat(ctx, payload, req.Token) + return resolvedContext{payload: payload, model: chatModel, provider: provider}, nil +} + +// --- Chat --- + +// Chat sends a synchronous chat request to the agent gateway and stores the result. +func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) { + rc, err := r.resolve(ctx, req) if err != nil { return ChatResponse{}, err } - resp.Messages = normalizeGatewayMessages(resp.Messages) - - if err := r.storeHistory(ctx, req.BotID, req.SessionID, req.Query, resp.Messages, resp.Skills); err != nil { + resp, err := r.postChat(ctx, rc.payload, req.Token) + if err != nil { return ChatResponse{}, err } - if err := r.storeMemory(ctx, req.BotID, req.SessionID, req.Query, resp.Messages); err != nil { + if err := r.storeRound(ctx, req.BotID, req.SessionID, req.Query, resp.Messages, resp.Skills); err != nil { return ChatResponse{}, err } - return ChatResponse{ Messages: resp.Messages, Skills: resp.Skills, - Model: chatModel.ModelID, - Provider: provider.ClientType, + Model: rc.model.ModelID, + Provider: rc.provider.ClientType, }, nil } -// ---------- TriggerSchedule ---------- +// --- TriggerSchedule --- -func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, schedule SchedulePayload, token string) error { +// TriggerSchedule executes a scheduled command through the chat gateway. +func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { if strings.TrimSpace(botID) == "" { return fmt.Errorf("bot id is required") } - if strings.TrimSpace(schedule.Command) == "" { + if strings.TrimSpace(payload.Command) == "" { return fmt.Errorf("schedule command is required") } - req := ChatRequest{ BotID: botID, - SessionID: "schedule:" + schedule.ID, - Query: schedule.Command, + SessionID: "schedule:" + payload.ID, + Query: payload.Command, + Token: token, } - settings, err := r.loadUserSettings(ctx, "") - if err != nil { - return err - } - chatModel, provider, err := r.selectChatModel(ctx, req, settings) - if err != nil { - return err - } - clientType, err := normalizeClientType(provider.ClientType) + rc, err := r.resolve(ctx, req) if err != nil { return err } + rc.payload.Identity.ContactID = botID + rc.payload.Identity.ContactName = "Scheduler" - maxContextLoadTime, _, err := r.loadBotSettings(ctx, botID) + resp, err := r.postChat(ctx, rc.payload, token) if err != nil { return err } - - messages, err := r.loadHistoryMessages(ctx, botID, req.SessionID, maxContextLoadTime) - if err != nil { - return err - } - historySkills, err := r.loadHistorySkills(ctx, botID, req.SessionID, maxContextLoadTime) - if err != nil { - return err - } - skills := normalizeSkills(historySkills) - containerID := r.resolveContainerID(ctx, botID, "") - - payload := agentGatewayRequest{ - Model: gatewayModelConfig{ - ModelID: chatModel.ModelID, - ClientType: clientType, - Input: chatModel.Input, - APIKey: provider.ApiKey, - BaseURL: provider.BaseUrl, - }, - ActiveContextTime: normalizeMaxContextLoad(maxContextLoadTime), - Messages: messages, - Skills: skills, - Query: schedule.Command, - Identity: gatewayIdentity{ - BotID: botID, - SessionID: req.SessionID, - ContainerID: containerID, - ContactID: botID, - ContactName: "Scheduler", - }, - } - - resp, err := r.postChat(ctx, payload, token) - if err != nil { - return err - } - resp.Messages = normalizeGatewayMessages(resp.Messages) - if err := r.storeHistory(ctx, botID, req.SessionID, schedule.Command, resp.Messages, resp.Skills); err != nil { - return err - } - if err := r.storeMemory(ctx, botID, req.SessionID, schedule.Command, resp.Messages); err != nil { - return err - } - return nil + return r.storeRound(ctx, botID, req.SessionID, payload.Command, resp.Messages, resp.Skills) } -// ---------- StreamChat ---------- +// --- StreamChat --- +// StreamChat sends a streaming chat request to the agent gateway. func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) { - chunkChan := make(chan StreamChunk) - errChan := make(chan error, 1) + chunkCh := make(chan StreamChunk) + errCh := make(chan error, 1) go func() { - defer close(chunkChan) - defer close(errChan) + defer close(chunkCh) + defer close(errCh) - if strings.TrimSpace(req.Query) == "" { - errChan <- fmt.Errorf("query is required") - return - } - if strings.TrimSpace(req.BotID) == "" { - errChan <- fmt.Errorf("bot id is required") - return - } - if strings.TrimSpace(req.SessionID) == "" { - errChan <- fmt.Errorf("session id is required") - return - } - skipHistory := req.MaxContextLoadTime < 0 - - settings, err := r.loadUserSettings(ctx, req.UserID) + rc, err := r.resolve(ctx, req) if err != nil { - errChan <- err + errCh <- err return } - chatModel, provider, err := r.selectChatModel(ctx, req, settings) - if err != nil { - errChan <- err - return - } - clientType, err := normalizeClientType(provider.ClientType) - if err != nil { - errChan <- err - return - } - - maxContextLoadTime, language, err := r.loadBotSettings(ctx, req.BotID) - if err != nil { - errChan <- err - return - } - if req.MaxContextLoadTime > 0 { - maxContextLoadTime = req.MaxContextLoadTime - } - if strings.TrimSpace(req.Language) != "" { - language = req.Language - } - - var messages []GatewayMessage - var historySkills []string - if !skipHistory { - messages, err = r.loadHistoryMessages(ctx, req.BotID, req.SessionID, maxContextLoadTime) - if err != nil { - errChan <- err - return - } - historySkills, err = r.loadHistorySkills(ctx, req.BotID, req.SessionID, maxContextLoadTime) - if err != nil { - errChan <- err - return - } - } - if len(req.Messages) > 0 { - messages = append(messages, req.Messages...) - } - messages = sanitizeGatewayMessages(messages) - messages = normalizeGatewayMessagesForModel(messages) - skills := normalizeSkills(append(historySkills, req.Skills...)) - containerID := r.resolveContainerID(ctx, req.BotID, req.ContainerID) - - payload := agentGatewayRequest{ - Model: gatewayModelConfig{ - ModelID: chatModel.ModelID, - ClientType: clientType, - Input: chatModel.Input, - APIKey: provider.ApiKey, - BaseURL: provider.BaseUrl, - }, - ActiveContextTime: normalizeMaxContextLoad(maxContextLoadTime), - Channels: req.Channels, - CurrentChannel: req.CurrentChannel, - AllowedActions: req.AllowedActions, - Messages: messages, - Skills: skills, - Query: req.Query, - Identity: gatewayIdentity{ - BotID: req.BotID, - SessionID: req.SessionID, - ContainerID: containerID, - ContactID: defaultString(req.ContactID, req.UserID, req.BotID), - ContactName: defaultString(req.ContactName, "User"), - ContactAlias: req.ContactAlias, - UserID: req.UserID, - CurrentPlatform: req.CurrentChannel, - ReplyTarget: req.ReplyTarget, - SessionToken: req.SessionToken, - }, - } - _ = language - - if err := r.streamChat(ctx, payload, req.BotID, req.SessionID, req.Query, req.Token, chunkChan); err != nil { - errChan <- err - return + if err := r.streamChat(ctx, rc.payload, req.BotID, req.SessionID, req.Query, req.Token, chunkCh); err != nil { + errCh <- err } }() - - return chunkChan, errChan + return chunkCh, errCh } -// ---------- HTTP helpers ---------- +// --- HTTP helpers --- -func (r *Resolver) postChat(ctx context.Context, payload agentGatewayRequest, token string) (agentGatewayResponse, error) { +func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token string) (gatewayResponse, error) { body, err := json.Marshal(payload) if err != nil { - return agentGatewayResponse{}, err + return gatewayResponse{}, err } url := r.gatewayBaseURL + "/chat/" r.logger.Info("gateway request", slog.String("url", url), slog.String("body_prefix", truncate(string(body), 200))) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return agentGatewayResponse{}, err + return gatewayResponse{}, err } - req.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Content-Type", "application/json") if strings.TrimSpace(token) != "" { - req.Header.Set("Authorization", token) + httpReq.Header.Set("Authorization", token) } - resp, err := r.httpClient.Do(req) + resp, err := r.httpClient.Do(httpReq) if err != nil { - return agentGatewayResponse{}, err + return gatewayResponse{}, err } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return agentGatewayResponse{}, err + return gatewayResponse{}, err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - r.logger.Error("gateway request failed", - slog.String("url", url), - slog.Int("status", resp.StatusCode), - slog.String("body_prefix", truncate(string(respBody), 300)), - ) - return agentGatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) + r.logger.Error("gateway error", slog.String("url", url), slog.Int("status", resp.StatusCode), slog.String("body_prefix", truncate(string(respBody), 300))) + return gatewayResponse{}, fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(respBody))) } - parsed, err := parseAgentGatewayResponse(respBody) - if err != nil { - r.logger.Error("failed to parse agent gateway response", slog.String("body", string(respBody)), slog.Any("error", err)) - return agentGatewayResponse{}, err + var parsed gatewayResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + r.logger.Error("gateway response parse failed", slog.String("body_prefix", truncate(string(respBody), 300)), slog.Any("error", err)) + return gatewayResponse{}, fmt.Errorf("failed to parse gateway response: %w", err) } return parsed, nil } -func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, botID, sessionID, query, token string, chunkChan chan<- StreamChunk) error { +func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, botID, sessionID, query, token string, chunkCh chan<- StreamChunk) error { body, err := json.Marshal(payload) if err != nil { return err } - url := r.gatewayBaseURL + "/chat/stream" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, r.gatewayBaseURL+"/chat/stream", bytes.NewReader(body)) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") if strings.TrimSpace(token) != "" { - req.Header.Set("Authorization", token) + httpReq.Header.Set("Authorization", token) } - resp, err := r.streamingClient.Do(req) + resp, err := r.streamingClient.Do(httpReq) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - payload, _ := io.ReadAll(resp.Body) - return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(payload))) + errBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("agent gateway error: %s", strings.TrimSpace(string(errBody))) } scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) - currentEventType := "" + currentEvent := "" stored := false for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -495,7 +356,7 @@ func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, continue } if strings.HasPrefix(line, "event:") { - currentEventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) continue } if !strings.HasPrefix(line, "data:") { @@ -505,26 +366,58 @@ func (r *Resolver) streamChat(ctx context.Context, payload agentGatewayRequest, if data == "" || data == "[DONE]" { continue } - chunkChan <- StreamChunk([]byte(data)) + chunkCh <- StreamChunk([]byte(data)) if stored { continue } - - if handled, err := r.tryStoreFromStreamPayload(ctx, botID, sessionID, query, currentEventType, data); err != nil { - return err + if handled, storeErr := r.tryStoreStream(ctx, botID, sessionID, query, currentEvent, data); storeErr != nil { + return storeErr } else if handled { stored = true } } - - if err := scanner.Err(); err != nil { - return err - } - return nil + return scanner.Err() } -// ---------- container resolution ---------- +// tryStoreStream attempts to extract final messages from a stream event and persist them. +func (r *Resolver) tryStoreStream(ctx context.Context, botID, sessionID, query, eventType, data string) (bool, error) { + // event: done + data: {messages: [...]} + if eventType == "done" { + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + } + } + + // data: {"type":"agent_end"|"done", ...} + var envelope struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills"` + } + if err := json.Unmarshal([]byte(data), &envelope); err == nil { + if envelope.Type == "agent_end" && len(envelope.Messages) > 0 { + return true, r.storeRound(ctx, botID, sessionID, query, envelope.Messages, envelope.Skills) + } + if envelope.Type == "done" && len(envelope.Data) > 0 { + var resp gatewayResponse + if err := json.Unmarshal(envelope.Data, &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + } + } + } + + // fallback: data: {messages: [...]} + var resp gatewayResponse + if err := json.Unmarshal([]byte(data), &resp); err == nil && len(resp.Messages) > 0 { + return true, r.storeRound(ctx, botID, sessionID, query, resp.Messages, resp.Skills) + } + return false, nil +} + +// --- container resolution --- func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit string) string { if strings.TrimSpace(explicit) != "" { @@ -542,377 +435,242 @@ func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit strin return "mcp-" + botID } -// ---------- history helpers ---------- +// --- history helpers --- -func (r *Resolver) loadHistoryMessages(ctx context.Context, botID, sessionID string, maxContextLoadTime int) ([]GatewayMessage, error) { +func (r *Resolver) loadHistoryMessages(ctx context.Context, botID, sessionID string, maxContextMinutes int) ([]ModelMessage, error) { if r.historyService == nil { return nil, fmt.Errorf("history service not configured") } - from := time.Now().UTC().Add(-time.Duration(normalizeMaxContextLoad(maxContextLoadTime)) * time.Minute) - records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, from) + since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) + records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, since) if err != nil { return nil, err } - messages := make([]GatewayMessage, 0, len(records)) + var messages []ModelMessage for _, record := range records { - if len(record.Messages) == 0 { + msgs, err := recordToMessages(record) + if err != nil { + r.logger.Warn("skip malformed history record", slog.String("record_id", record.ID), slog.Any("error", err)) continue } - for _, msg := range record.Messages { - if msg == nil { - continue - } - messages = append(messages, GatewayMessage(msg)) - } + messages = append(messages, msgs...) } return messages, nil } -func (r *Resolver) loadHistorySkills(ctx context.Context, botID, sessionID string, maxContextLoadTime int) ([]string, error) { +func (r *Resolver) loadHistorySkills(ctx context.Context, botID, sessionID string, maxContextMinutes int) ([]string, error) { if r.historyService == nil { return nil, fmt.Errorf("history service not configured") } - from := time.Now().UTC().Add(-time.Duration(normalizeMaxContextLoad(maxContextLoadTime)) * time.Minute) - records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, from) + since := time.Now().UTC().Add(-time.Duration(maxContextMinutes) * time.Minute) + records, err := r.historyService.ListBySessionSince(ctx, botID, sessionID, since) if err != nil { return nil, err } - combined := make([]string, 0, len(records)) + var combined []string for _, record := range records { - if len(record.Skills) == 0 { - continue - } combined = append(combined, record.Skills...) } - return normalizeSkills(combined), nil + return dedup(combined), nil } -// ---------- store helpers ---------- +// recordToMessages converts a history record (stored as []map[string]any) to typed ModelMessages. +func recordToMessages(record history.Record) ([]ModelMessage, error) { + if len(record.Messages) == 0 { + return nil, nil + } + raw, err := json.Marshal(record.Messages) + if err != nil { + return nil, err + } + var msgs []ModelMessage + if err := json.Unmarshal(raw, &msgs); err != nil { + return nil, err + } + return msgs, nil +} -func (r *Resolver) storeHistory(ctx context.Context, botID, sessionID, query string, responseMessages []GatewayMessage, skills []string) error { +// --- store helpers --- + +func (r *Resolver) storeRound(ctx context.Context, botID, sessionID, query string, messages []ModelMessage, skills []string) error { + if err := r.storeHistory(ctx, botID, sessionID, query, messages, skills); err != nil { + return err + } + r.storeMemory(ctx, botID, sessionID, query, messages) + return nil +} + +func (r *Resolver) storeHistory(ctx context.Context, botID, sessionID, query string, messages []ModelMessage, skills []string) error { if r.historyService == nil { return fmt.Errorf("history service not configured") } - if strings.TrimSpace(botID) == "" { - return fmt.Errorf("bot id is required") + if strings.TrimSpace(botID) == "" || strings.TrimSpace(sessionID) == "" { + return fmt.Errorf("bot id and session id are required") } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return fmt.Errorf("session id is required") - } - if strings.TrimSpace(query) == "" && len(responseMessages) == 0 { + if strings.TrimSpace(query) == "" && len(messages) == 0 { return nil } - messages := make([]map[string]any, 0, len(responseMessages)) - for _, msg := range responseMessages { - if msg == nil { - continue - } - messages = append(messages, map[string]any(msg)) + // Convert typed messages to []map[string]any for the history service. + raw, err := json.Marshal(messages) + if err != nil { + return err } - metadata := map[string]any{ - "query": strings.TrimSpace(query), + var rows []map[string]any + if err := json.Unmarshal(raw, &rows); err != nil { + return err } - _, err := r.historyService.Create(ctx, botID, trimmedSession, history.CreateRequest{ - Messages: messages, - Metadata: metadata, + _, err = r.historyService.Create(ctx, botID, strings.TrimSpace(sessionID), history.CreateRequest{ + Messages: rows, + Metadata: map[string]any{"query": strings.TrimSpace(query)}, Skills: skills, }) return err } -func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query string, responseMessages []GatewayMessage) error { +func (r *Resolver) storeMemory(ctx context.Context, botID, sessionID, query string, messages []ModelMessage) { if r.memoryService == nil { - return nil + return } - if strings.TrimSpace(botID) == "" { - return fmt.Errorf("bot id is required") + if strings.TrimSpace(botID) == "" || strings.TrimSpace(sessionID) == "" { + return } - trimmedSession := strings.TrimSpace(sessionID) - if trimmedSession == "" { - return fmt.Errorf("session id is required") - } - if strings.TrimSpace(query) == "" && len(responseMessages) == 0 { - return nil - } - - memoryMessages := make([]memory.Message, 0, len(responseMessages)) - for _, msg := range responseMessages { - role, content := gatewayMessageToMemory(msg) - if strings.TrimSpace(content) == "" { + memMsgs := make([]memory.Message, 0, len(messages)) + for _, msg := range messages { + text := strings.TrimSpace(msg.TextContent()) + if text == "" { continue } - memoryMessages = append(memoryMessages, memory.Message{ - Role: role, - Content: content, - }) + role := msg.Role + if strings.TrimSpace(role) == "" { + role = "assistant" + } + memMsgs = append(memMsgs, memory.Message{Role: role, Content: text}) } - if len(memoryMessages) == 0 { - return nil + if len(memMsgs) == 0 { + return } - - _, err := r.memoryService.Add(ctx, memory.AddRequest{ - Messages: memoryMessages, + if _, err := r.memoryService.Add(ctx, memory.AddRequest{ + Messages: memMsgs, BotID: botID, - SessionID: trimmedSession, - }) - return err + SessionID: strings.TrimSpace(sessionID), + }); err != nil { + r.logger.Warn("store memory failed", slog.Any("error", err)) + } } -func (r *Resolver) tryStoreFromStreamPayload(ctx context.Context, botID, sessionID, query, eventType, data string) (bool, error) { - // Case 1: event: done + data: {messages: [...]} - if eventType == "done" { - if parsed, ok := parseGatewayResponse([]byte(data)); ok { - parsed.Messages = normalizeGatewayMessages(parsed.Messages) - return r.storeRound(ctx, botID, sessionID, query, parsed.Messages, parsed.Skills) - } - } +// --- model selection --- - // Case 2: data: {"type":"agent_end","messages":[...],"skills":[...]} - var envelope struct { - Type string `json:"type"` - Data json.RawMessage `json:"data"` - Messages json.RawMessage `json:"messages"` - Skills []string `json:"skills"` - } - if err := json.Unmarshal([]byte(data), &envelope); err == nil { - if envelope.Type == "agent_end" { - // agent_end with inline messages - if len(envelope.Messages) > 0 { - if parsed, ok := parseGatewayResponseFromRaw(envelope.Messages, envelope.Skills); ok { - parsed.Messages = normalizeGatewayMessages(parsed.Messages) - return r.storeRound(ctx, botID, sessionID, query, parsed.Messages, parsed.Skills) - } - } - } - if envelope.Type == "done" && len(envelope.Data) > 0 { - if parsed, ok := parseGatewayResponse(envelope.Data); ok { - parsed.Messages = normalizeGatewayMessages(parsed.Messages) - return r.storeRound(ctx, botID, sessionID, query, parsed.Messages, parsed.Skills) - } - } - } - - // Case 3: data: {messages:[...]} without event - if parsed, ok := parseGatewayResponse([]byte(data)); ok { - parsed.Messages = normalizeGatewayMessages(parsed.Messages) - return r.storeRound(ctx, botID, sessionID, query, parsed.Messages, parsed.Skills) - } - return false, nil -} - -func parseGatewayResponse(payload []byte) (agentGatewayResponse, bool) { - parsed, err := parseAgentGatewayResponse(payload) - if err != nil { - return agentGatewayResponse{}, false - } - if len(parsed.Messages) == 0 { - return agentGatewayResponse{}, false - } - return parsed, true -} - -func parseGatewayResponseFromRaw(messagesRaw json.RawMessage, skills []string) (agentGatewayResponse, bool) { - var rawMessages []json.RawMessage - if err := json.Unmarshal(messagesRaw, &rawMessages); err != nil { - return agentGatewayResponse{}, false - } - messages := make([]GatewayMessage, 0, len(rawMessages)) - for _, rawMsg := range rawMessages { - var msg map[string]any - if err := json.Unmarshal(rawMsg, &msg); err != nil { - continue - } - messages = append(messages, GatewayMessage(msg)) - } - if len(messages) == 0 { - return agentGatewayResponse{}, false - } - return agentGatewayResponse{Messages: messages, Skills: skills}, true -} - -// parseAgentGatewayResponse parses the agent gateway response with flexible message handling. -func parseAgentGatewayResponse(payload []byte) (agentGatewayResponse, error) { - var raw struct { - Messages []json.RawMessage `json:"messages"` - Skills []string `json:"skills"` - } - if err := json.Unmarshal(payload, &raw); err != nil { - return agentGatewayResponse{}, fmt.Errorf("failed to parse response structure: %w", err) - } - - messages := make([]GatewayMessage, 0, len(raw.Messages)) - for _, rawMsg := range raw.Messages { - var msg map[string]any - if err := json.Unmarshal(rawMsg, &msg); err != nil { - var arr []any - if err := json.Unmarshal(rawMsg, &arr); err == nil { - for _, item := range arr { - if m, ok := item.(map[string]any); ok { - messages = append(messages, GatewayMessage(m)) - } - } - continue - } - continue - } - messages = append(messages, GatewayMessage(msg)) - } - - return agentGatewayResponse{ - Messages: messages, - Skills: raw.Skills, - }, nil -} - -func (r *Resolver) storeRound(ctx context.Context, botID, sessionID, query string, messages []GatewayMessage, skills []string) (bool, error) { - if err := r.storeHistory(ctx, botID, sessionID, query, messages, skills); err != nil { - return true, err - } - if err := r.storeMemory(ctx, botID, sessionID, query, messages); err != nil { - return true, err - } - return true, nil -} - -// ---------- model selection ---------- - -func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, settings userSettings) (models.GetResponse, sqlc.LlmProvider, error) { +func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, us resolvedUserSettings) (models.GetResponse, sqlc.LlmProvider, error) { if r.modelsService == nil { return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") } modelID := strings.TrimSpace(req.Model) providerFilter := strings.TrimSpace(req.Provider) + // Priority: request model > user settings > first available. + if modelID == "" && providerFilter == "" && strings.TrimSpace(us.ChatModelID) != "" { + modelID = us.ChatModelID + } + if modelID != "" && providerFilter == "" { - model, err := r.modelsService.GetByModelID(ctx, modelID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - if model.Type != models.ModelTypeChat { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") - } - provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return model, provider, nil + return r.fetchChatModel(ctx, modelID) } - if providerFilter == "" && modelID == "" && strings.TrimSpace(settings.ChatModelID) != "" { - selected, err := r.modelsService.GetByModelID(ctx, settings.ChatModelID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model not found: %w", err) - } - if selected.Type != models.ModelTypeChat { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("settings chat model is not a chat model") - } - provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) - if err != nil { - return models.GetResponse{}, sqlc.LlmProvider{}, err - } - return selected, provider, nil - } - - var candidates []models.GetResponse - var err error - if providerFilter != "" { - candidates, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) - } else { - candidates, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) - } + candidates, err := r.listCandidates(ctx, providerFilter) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } - - filtered := make([]models.GetResponse, 0, len(candidates)) - for _, model := range candidates { - if model.Type != models.ModelTypeChat { - continue - } - filtered = append(filtered, model) - } - if len(filtered) == 0 { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available") - } - if modelID != "" { - for _, model := range filtered { - if model.ModelID == modelID { - provider, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + for _, m := range candidates { + if m.ModelID == modelID { + prov, err := models.FetchProviderByID(ctx, r.queries, m.LlmProviderID) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } - return model, provider, nil + return m, prov, nil } } return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not found") } - - selected := filtered[0] - provider, err := models.FetchProviderByID(ctx, r.queries, selected.LlmProviderID) + if len(candidates) == 0 { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available") + } + prov, err := models.FetchProviderByID(ctx, r.queries, candidates[0].LlmProviderID) if err != nil { return models.GetResponse{}, sqlc.LlmProvider{}, err } - return selected, provider, nil + return candidates[0], prov, nil } -// ---------- settings helpers ---------- - -func normalizeMaxContextLoad(value int) int { - if value <= 0 { - return defaultMaxContextMinutes - } - return value -} - -func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (userSettings, error) { - defaults := userSettings{ - MaxContextLoadTime: defaultMaxContextMinutes, - Language: settings.DefaultLanguage, - } - if r.settingsService == nil || strings.TrimSpace(userID) == "" { - return defaults, nil - } - settingsRow, err := r.settingsService.Get(ctx, userID) +func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { + model, err := r.modelsService.GetByModelID(ctx, modelID) if err != nil { - return userSettings{}, err + return models.GetResponse{}, sqlc.LlmProvider{}, err } - maxLoad := settingsRow.MaxContextLoadTime - if maxLoad <= 0 { - maxLoad = defaultMaxContextMinutes + if model.Type != models.ModelTypeChat { + return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") } - language := strings.TrimSpace(settingsRow.Language) - if language == "" || language == "auto" { - language = settings.DefaultLanguage + prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) + if err != nil { + return models.GetResponse{}, sqlc.LlmProvider{}, err } - return userSettings{ - ChatModelID: strings.TrimSpace(settingsRow.ChatModelID), - MemoryModelID: strings.TrimSpace(settingsRow.MemoryModelID), - EmbeddingModelID: strings.TrimSpace(settingsRow.EmbeddingModelID), - MaxContextLoadTime: maxLoad, - Language: language, + return model, prov, nil +} + +func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([]models.GetResponse, error) { + var all []models.GetResponse + var err error + if providerFilter != "" { + all, err = r.modelsService.ListByClientType(ctx, models.ClientType(providerFilter)) + } else { + all, err = r.modelsService.ListByType(ctx, models.ModelTypeChat) + } + if err != nil { + return nil, err + } + filtered := make([]models.GetResponse, 0, len(all)) + for _, m := range all { + if m.Type == models.ModelTypeChat { + filtered = append(filtered, m) + } + } + return filtered, nil +} + +// --- settings --- + +type resolvedUserSettings struct { + ChatModelID string +} + +func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (resolvedUserSettings, error) { + if r.settingsService == nil || strings.TrimSpace(userID) == "" { + return resolvedUserSettings{}, nil + } + s, err := r.settingsService.Get(ctx, userID) + if err != nil { + return resolvedUserSettings{}, err + } + return resolvedUserSettings{ + ChatModelID: strings.TrimSpace(s.ChatModelID), }, nil } -func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (int, string, error) { +func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) { if r.settingsService == nil { - return settings.DefaultMaxContextLoadTime, settings.DefaultLanguage, nil + return settings.Settings{ + MaxContextLoadTime: settings.DefaultMaxContextLoadTime, + Language: settings.DefaultLanguage, + }, nil } - settingsRow, err := r.settingsService.GetBot(ctx, botID) - if err != nil { - return 0, "", err - } - return settingsRow.MaxContextLoadTime, settingsRow.Language, nil + return r.settingsService.GetBot(ctx, botID) } -// ---------- utility ---------- +// --- utility --- func normalizeClientType(clientType string) (string, error) { switch strings.ToLower(strings.TrimSpace(clientType)) { - case "openai": - return "openai", nil - case "openai-compat": + case "openai", "openai-compat": return "openai", nil case "anthropic": return "anthropic", nil @@ -923,51 +681,13 @@ func normalizeClientType(clientType string) (string, error) { } } -func normalizeSkills(skills []string) []string { - seen := map[string]struct{}{} - normalized := make([]string, 0, len(skills)) - for _, skill := range skills { - trimmed := strings.TrimSpace(skill) - if trimmed == "" { - continue - } - if _, ok := seen[trimmed]; ok { - continue - } - seen[trimmed] = struct{}{} - normalized = append(normalized, trimmed) - } - return normalized -} - -func gatewayMessageToMemory(msg GatewayMessage) (string, string) { - role := "assistant" - if raw, ok := msg["role"].(string); ok && strings.TrimSpace(raw) != "" { - role = raw - } - if raw, ok := msg["content"]; ok { - switch v := raw.(type) { - case string: - return role, v - default: - if encoded, err := json.Marshal(v); err == nil { - return role, string(encoded) - } - } - } - if encoded, err := json.Marshal(msg); err == nil { - return role, string(encoded) - } - return role, "" -} - -func sanitizeGatewayMessages(messages []GatewayMessage) []GatewayMessage { - if len(messages) == 0 { - return messages - } - cleaned := make([]GatewayMessage, 0, len(messages)) +func sanitizeMessages(messages []ModelMessage) []ModelMessage { + cleaned := make([]ModelMessage, 0, len(messages)) for _, msg := range messages { - if !isMeaningfulGatewayMessage(msg) { + if strings.TrimSpace(msg.Role) == "" { + continue + } + if !msg.HasContent() && strings.TrimSpace(msg.ToolCallID) == "" { continue } cleaned = append(cleaned, msg) @@ -975,91 +695,53 @@ func sanitizeGatewayMessages(messages []GatewayMessage) []GatewayMessage { return cleaned } -func normalizeGatewayMessagesForModel(messages []GatewayMessage) []GatewayMessage { - if len(messages) == 0 { - return messages - } - cleaned := make([]GatewayMessage, 0, len(messages)) - for _, msg := range messages { - if msg == nil { +func dedup(items []string) []string { + seen := make(map[string]struct{}, len(items)) + result := make([]string, 0, len(items)) + for _, s := range items { + trimmed := strings.TrimSpace(s) + if trimmed == "" { continue } - role, content := gatewayMessageToMemory(msg) - content = strings.TrimSpace(content) - if content == "" { + if _, ok := seen[trimmed]; ok { continue } - if strings.TrimSpace(role) == "" { - role = "assistant" - } - if role == "tool" { - role = "assistant" - content = "[tool] " + content - } - cleaned = append(cleaned, GatewayMessage{ - "role": role, - "content": content, - }) + seen[trimmed] = struct{}{} + result = append(result, trimmed) } - return cleaned + return result } -func isMeaningfulGatewayMessage(msg GatewayMessage) bool { - if len(msg) == 0 { - return false - } - if raw, ok := msg["role"].(string); ok && strings.TrimSpace(raw) != "" { - return true - } - if raw, ok := msg["content"]; ok { - switch v := raw.(type) { - case string: - if strings.TrimSpace(v) != "" { - return true - } - default: - if !isEmptyValue(v) { - return true - } +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v } } - for _, value := range msg { - if !isEmptyValue(value) { - return true - } - } - return false + return "" } -func isEmptyValue(value any) bool { - switch v := value.(type) { - case nil: - return true - case string: - return strings.TrimSpace(v) == "" - case []any: - if len(v) == 0 { - return true +func coalescePositiveInt(values ...int) int { + for _, v := range values { + if v > 0 { + return v } - for _, item := range v { - if !isEmptyValue(item) { - return false - } - } - return true - case map[string]any: - if len(v) == 0 { - return true - } - for _, item := range v { - if !isEmptyValue(item) { - return false - } - } - return true - default: - return false } + return defaultMaxContextMinutes +} + +func nonNilStrings(s []string) []string { + if s == nil { + return []string{} + } + return s +} + +func nonNilMessages(m []ModelMessage) []ModelMessage { + if m == nil { + return []ModelMessage{} + } + return m } func truncate(s string, n int) string { @@ -1069,31 +751,14 @@ func truncate(s string, n int) string { return s[:n] + "..." } -func defaultString(values ...string) string { - for _, v := range values { - if strings.TrimSpace(v) != "" { - return v - } - } - return "" -} - func parseUUID(id string) (pgtype.UUID, error) { - parsed, err := parseUUIDHelper(id) - if err != nil { - return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) - } - return parsed, nil -} - -func parseUUIDHelper(id string) (pgtype.UUID, error) { trimmed := strings.TrimSpace(id) if trimmed == "" { return pgtype.UUID{}, fmt.Errorf("empty id") } var pgID pgtype.UUID if err := pgID.Scan(trimmed); err != nil { - return pgtype.UUID{}, err + return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) } return pgID, nil } diff --git a/internal/chat/schedule_gateway.go b/internal/chat/schedule_gateway.go index d6ddd28c..d1578065 100644 --- a/internal/chat/schedule_gateway.go +++ b/internal/chat/schedule_gateway.go @@ -7,25 +7,20 @@ import ( "github.com/memohai/memoh/internal/schedule" ) -// ScheduleGateway 将 schedule 触发请求转交给 chat Resolver。 +// ScheduleGateway adapts schedule trigger calls to the chat Resolver. type ScheduleGateway struct { resolver *Resolver } +// NewScheduleGateway creates a ScheduleGateway backed by the given Resolver. func NewScheduleGateway(resolver *Resolver) *ScheduleGateway { return &ScheduleGateway{resolver: resolver} } +// TriggerSchedule delegates a schedule trigger to the chat Resolver. func (g *ScheduleGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { if g == nil || g.resolver == nil { return fmt.Errorf("chat resolver not configured") } - return g.resolver.TriggerSchedule(ctx, botID, SchedulePayload{ - ID: payload.ID, - Name: payload.Name, - Description: payload.Description, - Pattern: payload.Pattern, - MaxCalls: payload.MaxCalls, - Command: payload.Command, - }, token) + return g.resolver.TriggerSchedule(ctx, botID, payload, token) } diff --git a/internal/chat/types.go b/internal/chat/types.go index 47a32a0c..51bb60fe 100644 --- a/internal/chat/types.go +++ b/internal/chat/types.go @@ -1,65 +1,77 @@ +// Package chat orchestrates conversations with the agent gateway, including +// synchronous and streaming chat, scheduled triggers, history, and memory storage. package chat -import "encoding/json" +import ( + "encoding/json" + "strings" +) -type Message struct { - Role string `json:"role"` - Content string `json:"content"` +// ModelMessage is the canonical message format exchanged with the agent gateway. +// Aligned with Vercel AI SDK ModelMessage structure. +type ModelMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` } -type GatewayMessage map[string]any - -type ChatRequest struct { - BotID string `json:"-"` - SessionID string `json:"-"` - Token string `json:"-"` - UserID string `json:"-"` - ContainerID string `json:"-"` - ContactID string `json:"-"` - ContactName string `json:"-"` - ContactAlias string `json:"-"` - ReplyTarget string `json:"-"` - SessionToken string `json:"-"` - Query string `json:"query"` - Model string `json:"model,omitempty"` - Provider string `json:"provider,omitempty"` - MaxContextLoadTime int `json:"max_context_load_time,omitempty"` - Language string `json:"language,omitempty"` - Channels []string `json:"channels,omitempty"` - CurrentChannel string `json:"current_channel,omitempty"` - Messages []GatewayMessage `json:"messages,omitempty"` - Skills []string `json:"skills,omitempty"` - AllowedActions []string `json:"allowed_actions,omitempty"` +// TextContent extracts the plain text from the message content. +// If content is a string, it returns it directly. +// If content is an array of parts, it joins all text-type parts. +func (m ModelMessage) TextContent() string { + if len(m.Content) == 0 { + return "" + } + var s string + if err := json.Unmarshal(m.Content, &s); err == nil { + return s + } + var parts []ContentPart + if err := json.Unmarshal(m.Content, &parts); err == nil { + texts := make([]string, 0, len(parts)) + for _, p := range parts { + if strings.TrimSpace(p.Text) != "" { + texts = append(texts, p.Text) + } + } + return strings.Join(texts, "\n") + } + return "" } -type ChatResponse struct { - Messages []GatewayMessage `json:"messages"` - Skills []string `json:"skills,omitempty"` - Model string `json:"model,omitempty"` - Provider string `json:"provider,omitempty"` +// ContentParts parses the content as an array of ContentPart. +// Returns nil if the content is a plain string or not parseable. +func (m ModelMessage) ContentParts() []ContentPart { + if len(m.Content) == 0 { + return nil + } + var parts []ContentPart + if err := json.Unmarshal(m.Content, &parts); err != nil { + return nil + } + return parts } -type StreamChunk = json.RawMessage - -type SchedulePayload struct { - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Pattern string `json:"pattern"` - MaxCalls *int `json:"maxCalls,omitempty"` - Command string `json:"command"` +// HasContent reports whether the message carries non-empty content or tool calls. +func (m ModelMessage) HasContent() bool { + if strings.TrimSpace(m.TextContent()) != "" { + return true + } + if len(m.ContentParts()) > 0 { + return true + } + return len(m.ToolCalls) > 0 } -// NormalizedMessage is the internal unified message structure. -type NormalizedMessage struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Parts []ContentPart `json:"parts,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - Name string `json:"name,omitempty"` +// NewTextContent creates a json.RawMessage from a plain string. +func NewTextContent(text string) json.RawMessage { + data, _ := json.Marshal(text) + return data } +// ContentPart represents one element of a multi-part message content. type ContentPart struct { Type string `json:"type"` Text string `json:"text,omitempty"` @@ -71,13 +83,64 @@ type ContentPart struct { Metadata map[string]any `json:"metadata,omitempty"` } +// HasValue reports whether the content part carries a meaningful value. +func (p ContentPart) HasValue() bool { + return strings.TrimSpace(p.Text) != "" || + strings.TrimSpace(p.URL) != "" || + strings.TrimSpace(p.Emoji) != "" +} + +// ToolCall represents a function/tool invocation in an assistant message. type ToolCall struct { ID string `json:"id,omitempty"` Type string `json:"type"` Function ToolCallFunction `json:"function"` } +// ToolCallFunction holds the name and serialized arguments of a tool call. type ToolCallFunction struct { Name string `json:"name"` Arguments string `json:"arguments"` } + +// ChatRequest is the input for Chat and StreamChat. +type ChatRequest struct { + BotID string `json:"-"` + SessionID string `json:"-"` + Token string `json:"-"` + UserID string `json:"-"` + ContainerID string `json:"-"` + ContactID string `json:"-"` + ContactName string `json:"-"` + ContactAlias string `json:"-"` + ReplyTarget string `json:"-"` + SessionToken string `json:"-"` + + Query string `json:"query"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` + MaxContextLoadTime int `json:"max_context_load_time,omitempty"` + Language string `json:"language,omitempty"` + Channels []string `json:"channels,omitempty"` + CurrentChannel string `json:"current_channel,omitempty"` + Messages []ModelMessage `json:"messages,omitempty"` + Skills []string `json:"skills,omitempty"` + AllowedActions []string `json:"allowed_actions,omitempty"` +} + +// ChatResponse is the output of a non-streaming chat call. +type ChatResponse struct { + Messages []ModelMessage `json:"messages"` + Skills []string `json:"skills,omitempty"` + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` +} + +// StreamChunk is a raw JSON chunk from the streaming response. +type StreamChunk = json.RawMessage + +// AssistantOutput holds extracted assistant content for downstream consumers. +type AssistantOutput struct { + Content string + Parts []ContentPart +} diff --git a/internal/db/sqlc/containers.sql.go b/internal/db/sqlc/containers.sql.go index 7771f34a..3141781b 100644 --- a/internal/db/sqlc/containers.sql.go +++ b/internal/db/sqlc/containers.sql.go @@ -11,6 +11,15 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const deleteContainerByBotID = `-- name: DeleteContainerByBotID :exec +DELETE FROM containers WHERE bot_id = $1 +` + +func (q *Queries) DeleteContainerByBotID(ctx context.Context, botID pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteContainerByBotID, botID) + return err +} + const getContainerByBotID = `-- name: GetContainerByBotID :one SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE bot_id = $1 ORDER BY updated_at DESC LIMIT 1 ` @@ -63,6 +72,44 @@ func (q *Queries) GetContainerByContainerID(ctx context.Context, containerID str return i, err } +const updateContainerStarted = `-- name: UpdateContainerStarted :exec +UPDATE containers +SET status = 'running', last_started_at = now(), updated_at = now() +WHERE bot_id = $1 +` + +func (q *Queries) UpdateContainerStarted(ctx context.Context, botID pgtype.UUID) error { + _, err := q.db.Exec(ctx, updateContainerStarted, botID) + return err +} + +const updateContainerStatus = `-- name: UpdateContainerStatus :exec +UPDATE containers +SET status = $1, updated_at = now() +WHERE bot_id = $2 +` + +type UpdateContainerStatusParams struct { + Status string `json:"status"` + BotID pgtype.UUID `json:"bot_id"` +} + +func (q *Queries) UpdateContainerStatus(ctx context.Context, arg UpdateContainerStatusParams) error { + _, err := q.db.Exec(ctx, updateContainerStatus, arg.Status, arg.BotID) + return err +} + +const updateContainerStopped = `-- name: UpdateContainerStopped :exec +UPDATE containers +SET status = 'stopped', last_stopped_at = now(), updated_at = now() +WHERE bot_id = $1 +` + +func (q *Queries) UpdateContainerStopped(ctx context.Context, botID pgtype.UUID) error { + _, err := q.db.Exec(ctx, updateContainerStopped, botID) + return err +} + const upsertContainer = `-- name: UpsertContainer :exec INSERT INTO containers ( bot_id, container_id, container_name, image, status, namespace, auto_start, diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index a9e7977b..c6789315 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -44,7 +44,6 @@ type ContainerdHandler struct { } type CreateContainerRequest struct { - ContainerID string `json:"container_id"` Image string `json:"image,omitempty"` Snapshotter string `json:"snapshotter,omitempty"` } @@ -56,8 +55,19 @@ type CreateContainerResponse struct { Started bool `json:"started"` } +type GetContainerResponse struct { + ContainerID string `json:"container_id"` + Image string `json:"image"` + Status string `json:"status"` + Namespace string `json:"namespace"` + HostPath string `json:"host_path,omitempty"` + ContainerPath string `json:"container_path"` + TaskRunning bool `json:"task_running"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type CreateSnapshotRequest struct { - ContainerID string `json:"container_id"` SnapshotName string `json:"snapshot_name"` } @@ -67,20 +77,6 @@ type CreateSnapshotResponse struct { Snapshotter string `json:"snapshotter"` } -type ContainerInfo struct { - ID string `json:"id"` - Image string `json:"image,omitempty"` - Snapshotter string `json:"snapshotter,omitempty"` - SnapshotKey string `json:"snapshot_key,omitempty"` - CreatedAt time.Time `json:"created_at,omitempty"` - UpdatedAt time.Time `json:"updated_at,omitempty"` - Labels map[string]string `json:"labels,omitempty"` -} - -type ListContainersResponse struct { - Containers []ContainerInfo `json:"containers"` -} - type SnapshotInfo struct { Snapshotter string `json:"snapshotter"` Name string `json:"name"` @@ -112,14 +108,16 @@ func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPC func (h *ContainerdHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/container") group.POST("", h.CreateContainer) - group.GET("/list", h.ListContainers) - group.DELETE("/:id", h.DeleteContainer) + group.GET("", h.GetContainer) + group.DELETE("", h.DeleteContainer) + group.POST("/start", h.StartContainer) + group.POST("/stop", h.StopContainer) group.POST("/snapshots", h.CreateSnapshot) group.GET("/snapshots", h.ListSnapshots) group.GET("/skills", h.ListSkills) group.POST("/skills", h.UpsertSkills) group.DELETE("/skills", h.DeleteSkills) - group.POST("/fs/:id", h.HandleMCPFS) + group.POST("/fs", h.HandleMCPFS) } // CreateContainer godoc @@ -141,10 +139,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - req.ContainerID = strings.TrimSpace(req.ContainerID) - if req.ContainerID == "" { - req.ContainerID = "mcp-" + botID - } + containerID := mcp.ContainerPrefix + botID image := strings.TrimSpace(req.Image) if image == "" { @@ -166,6 +161,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { if dataRoot == "" { dataRoot = config.DefaultDataRoot } + dataRoot, _ = filepath.Abs(dataRoot) dataMount := strings.TrimSpace(h.cfg.DataMount) if dataMount == "" { dataMount = config.DefaultDataMount @@ -197,7 +193,7 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } _, err = h.service.CreateContainer(ctx, ctr.CreateContainerRequest{ - ID: req.ContainerID, + ID: containerID, ImageRef: image, Snapshotter: snapshotter, Labels: map[string]string{ @@ -209,7 +205,6 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, "snapshotter="+snapshotter+" image="+image+" err="+err.Error()) } - // Persist container record in database if h.queries != nil { pgBotID, parseErr := parsePgUUID(botID) if parseErr == nil { @@ -219,8 +214,8 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { } _ = h.queries.UpsertContainer(c.Request().Context(), dbsqlc.UpsertContainerParams{ BotID: pgBotID, - ContainerID: req.ContainerID, - ContainerName: req.ContainerID, + ContainerID: containerID, + ContainerName: containerID, Image: image, Status: "created", Namespace: ns, @@ -236,20 +231,25 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - if _, err := h.service.StartTask(c.Request().Context(), req.ContainerID, &ctr.StartTaskOptions{ + if _, err := h.service.StartTask(ctx, containerID, &ctr.StartTaskOptions{ UseStdio: false, FIFODir: fifoDir, }); err == nil { started = true + if h.queries != nil { + if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { + _ = h.queries.UpdateContainerStarted(c.Request().Context(), pgBotID) + } + } } else { h.logger.Error("mcp container start failed", - slog.String("container_id", req.ContainerID), + slog.String("container_id", containerID), slog.Any("error", err), ) } return c.JSON(http.StatusOK, CreateContainerResponse{ - ContainerID: req.ContainerID, + ContainerID: containerID, Image: image, Snapshotter: snapshotter, Started: started, @@ -334,122 +334,212 @@ func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (s return bestID, nil } -// ListContainers godoc -// @Summary List containers for bot +// GetContainer godoc +// @Summary Get container info for bot // @Tags containerd // @Param bot_id path string true "Bot ID" -// @Success 200 {object} ListContainersResponse +// @Success 200 {object} GetContainerResponse +// @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/list [get] -func (h *ContainerdHandler) ListContainers(c echo.Context) error { +// @Router /bots/{bot_id}/container [get] +func (h *ContainerdHandler) GetContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { return err } ctx := c.Request().Context() - containers, err := h.service.ListContainersByLabel(ctx, mcp.BotLabelKey, botID) + + if h.queries != nil { + pgBotID, parseErr := parsePgUUID(botID) + if parseErr == nil { + row, dbErr := h.queries.GetContainerByBotID(ctx, pgBotID) + if dbErr == nil { + taskRunning := h.isTaskRunning(ctx, row.ContainerID) + hostPath := "" + if row.HostPath.Valid { + hostPath = row.HostPath.String + } + createdAt := time.Time{} + if row.CreatedAt.Valid { + createdAt = row.CreatedAt.Time + } + updatedAt := time.Time{} + if row.UpdatedAt.Valid { + updatedAt = row.UpdatedAt.Time + } + return c.JSON(http.StatusOK, GetContainerResponse{ + ContainerID: row.ContainerID, + Image: row.Image, + Status: row.Status, + Namespace: row.Namespace, + HostPath: hostPath, + ContainerPath: row.ContainerPath, + TaskRunning: taskRunning, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }) + } + } + } + + containerID, err := h.botContainerID(ctx, botID) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") } infoCtx := ctx if strings.TrimSpace(h.namespace) != "" { infoCtx = namespaces.WithNamespace(ctx, h.namespace) } - items := make([]ContainerInfo, 0, len(containers)) - for _, container := range containers { - info, err := container.Info(infoCtx) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - items = append(items, ContainerInfo{ - ID: info.ID, - Image: info.Image, - Snapshotter: info.Snapshotter, - SnapshotKey: info.SnapshotKey, - CreatedAt: info.CreatedAt, - UpdatedAt: info.UpdatedAt, - Labels: info.Labels, - }) - } - sort.Slice(items, func(i, j int) bool { - return items[i].ID < items[j].ID - }) - return c.JSON(http.StatusOK, ListContainersResponse{Containers: items}) -} - -// DeleteContainer godoc -// @Summary Delete MCP container -// @Tags containerd -// @Param bot_id path string true "Bot ID" -// @Param id path string true "Container ID" -// @Success 204 -// @Failure 400 {object} ErrorResponse -// @Failure 404 {object} ErrorResponse -// @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/{id} [delete] -func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { - if _, err := h.requireBotAccess(c); err != nil { - return err - } - containerID := strings.TrimSpace(c.Param("id")) - if containerID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "container id is required") - } - _ = h.service.DeleteTask(c.Request().Context(), containerID, &ctr.DeleteTaskOptions{Force: true}) - if err := h.service.DeleteContainer(c.Request().Context(), containerID, &ctr.DeleteContainerOptions{ - CleanupSnapshot: true, - }); err != nil { + container, err := h.service.GetContainer(infoCtx, containerID) + if err != nil { if errdefs.IsNotFound(err) { return echo.NewHTTPError(http.StatusNotFound, "container not found") } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + info, err := container.Info(infoCtx) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, GetContainerResponse{ + ContainerID: info.ID, + Image: info.Image, + Status: "unknown", + Namespace: h.namespace, + TaskRunning: h.isTaskRunning(ctx, containerID), + CreatedAt: info.CreatedAt, + UpdatedAt: info.UpdatedAt, + }) +} + +// DeleteContainer godoc +// @Summary Delete MCP container for bot +// @Tags containerd +// @Param bot_id path string true "Bot ID" +// @Success 204 +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/container [delete] +func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } + if err := h.CleanupBotContainer(c.Request().Context(), botID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } return c.NoContent(http.StatusNoContent) } +// StartContainer godoc +// @Summary Start container task for bot +// @Tags containerd +// @Param bot_id path string true "Bot ID" +// @Success 200 {object} object +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/container/start [post] +func (h *ContainerdHandler) StartContainer(c echo.Context) error { + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } + ctx := c.Request().Context() + containerID, err := h.botContainerID(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") + } + if err := h.ensureTaskRunning(ctx, containerID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if h.queries != nil { + if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { + _ = h.queries.UpdateContainerStarted(ctx, pgBotID) + } + } + return c.JSON(http.StatusOK, map[string]bool{"started": true}) +} + +// StopContainer godoc +// @Summary Stop container task for bot +// @Tags containerd +// @Param bot_id path string true "Bot ID" +// @Success 200 {object} object +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /bots/{bot_id}/container/stop [post] +func (h *ContainerdHandler) StopContainer(c echo.Context) error { + botID, err := h.requireBotAccess(c) + if err != nil { + return err + } + ctx := c.Request().Context() + containerID, err := h.botContainerID(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") + } + if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{ + Timeout: 10 * time.Second, + Force: true, + }); err != nil && !errdefs.IsNotFound(err) { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + _ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}) + if h.queries != nil { + if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { + _ = h.queries.UpdateContainerStopped(ctx, pgBotID) + } + } + return c.JSON(http.StatusOK, map[string]bool{"stopped": true}) +} + // CreateSnapshot godoc -// @Summary Create container snapshot +// @Summary Create container snapshot for bot // @Tags containerd // @Param bot_id path string true "Bot ID" // @Param payload body CreateSnapshotRequest true "Create snapshot payload" // @Success 200 {object} CreateSnapshotResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/container/snapshots [post] func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { - if _, err := h.requireBotAccess(c); err != nil { + botID, err := h.requireBotAccess(c) + if err != nil { return err } var req CreateSnapshotRequest if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - if strings.TrimSpace(req.ContainerID) == "" { - return echo.NewHTTPError(http.StatusBadRequest, "container_id is required") + ctx := c.Request().Context() + containerID, err := h.botContainerID(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") } - container, err := h.service.GetContainer(c.Request().Context(), req.ContainerID) + container, err := h.service.GetContainer(ctx, containerID) if err != nil { if errdefs.IsNotFound(err) { return echo.NewHTTPError(http.StatusNotFound, "container not found") } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - ctx := c.Request().Context() + infoCtx := ctx if strings.TrimSpace(h.namespace) != "" { - ctx = namespaces.WithNamespace(ctx, h.namespace) + infoCtx = namespaces.WithNamespace(ctx, h.namespace) } - info, err := container.Info(ctx) + info, err := container.Info(infoCtx) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } snapshotName := strings.TrimSpace(req.SnapshotName) if snapshotName == "" { - snapshotName = req.ContainerID + "-" + time.Now().Format("20060102150405") + snapshotName = containerID + "-" + time.Now().Format("20060102150405") } - if err := h.service.CommitSnapshot(c.Request().Context(), info.Snapshotter, snapshotName, info.SnapshotKey); err != nil { + if err := h.service.CommitSnapshot(ctx, info.Snapshotter, snapshotName, info.SnapshotKey); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - return c.JSON(http.StatusOK, CreateSnapshotResponse{ - ContainerID: req.ContainerID, + ContainerID: containerID, SnapshotName: snapshotName, Snapshotter: info.Snapshotter, }) @@ -551,6 +641,45 @@ func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, actorID, bot return bot, nil } +// CleanupBotContainer removes the containerd container and DB record for a bot. +func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID string) error { + containerID, err := h.botContainerID(ctx, botID) + if err != nil { + if h.queries != nil { + if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { + _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) + } + } + return nil + } + + _ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{ + Timeout: 5 * time.Second, + Force: true, + }) + _ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}) + + if err := h.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{ + CleanupSnapshot: true, + }); err != nil && !errdefs.IsNotFound(err) { + return err + } + + if h.queries != nil { + if pgBotID, parseErr := parsePgUUID(botID); parseErr == nil { + _ = h.queries.DeleteContainerByBotID(ctx, pgBotID) + } + } + return nil +} + +func (h *ContainerdHandler) isTaskRunning(ctx context.Context, containerID string) bool { + tasks, err := h.service.ListTasks(ctx, &ctr.ListTasksOptions{ + Filter: "container.id==" + containerID, + }) + return err == nil && len(tasks) > 0 && tasks[0].Status == tasktypes.Status_RUNNING +} + func parsePgUUID(id string) (pgtype.UUID, error) { parsed, err := uuid.Parse(strings.TrimSpace(id)) if err != nil { diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go index 94a0dcb5..6425a569 100644 --- a/internal/handlers/fs.go +++ b/internal/handlers/fs.go @@ -50,9 +50,10 @@ func (h *ContainerdHandler) HandleMCPFS(c echo.Context) error { if err != nil { return err } - containerID := strings.TrimSpace(c.Param("id")) - if containerID == "" { - return echo.NewHTTPError(http.StatusBadRequest, "container id is required") + ctx := c.Request().Context() + containerID, err := h.botContainerID(ctx, botID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, "container not found for bot") } var req mcptools.JSONRPCRequest diff --git a/internal/router/channel.go b/internal/router/channel.go index 184d5ef2..74c97127 100644 --- a/internal/router/channel.go +++ b/internal/router/channel.go @@ -352,29 +352,23 @@ type sendMessageToolArgs struct { Message *channel.Message `json:"message"` } -type toolCall struct { - Name string - Arguments string -} - -func collectMessageToolContext(messages []chat.GatewayMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { +func collectMessageToolContext(messages []chat.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { if len(messages) == 0 { return nil, false } - sentTexts := make([]string, 0) + var sentTexts []string suppressReplies := false for _, msg := range messages { - for _, call := range extractToolCalls(msg) { - if call.Name != "send_message" { + for _, tc := range msg.ToolCalls { + if tc.Function.Name != "send_message" { continue } var args sendMessageToolArgs - if !parseToolArguments(call.Arguments, &args) { + if !parseToolArguments(tc.Function.Arguments, &args) { continue } - messageText := strings.TrimSpace(extractSendMessageText(args)) - if messageText != "" { - sentTexts = append(sentTexts, messageText) + if text := strings.TrimSpace(extractSendMessageText(args)); text != "" { + sentTexts = append(sentTexts, text) } if shouldSuppressForToolCall(args, channelType, replyTarget) { suppressReplies = true @@ -384,60 +378,6 @@ func collectMessageToolContext(messages []chat.GatewayMessage, channelType chann return sentTexts, suppressReplies } -func extractToolCalls(msg chat.GatewayMessage) []toolCall { - calls := make([]toolCall, 0) - if msg == nil { - return calls - } - if rawCalls, ok := msg["tool_calls"].([]any); ok { - for _, raw := range rawCalls { - call, ok := raw.(map[string]any) - if !ok { - continue - } - name, args := parseToolCall(call) - if name == "" { - continue - } - calls = append(calls, toolCall{Name: name, Arguments: args}) - } - } - if fn, ok := msg["function_call"].(map[string]any); ok { - name := readString(fn["name"]) - args := readString(fn["arguments"]) - if name != "" { - calls = append(calls, toolCall{Name: name, Arguments: args}) - } - } - if fn, ok := msg["functionCall"].(map[string]any); ok { - name := readString(fn["name"]) - args := readString(fn["arguments"]) - if name != "" { - calls = append(calls, toolCall{Name: name, Arguments: args}) - } - } - return calls -} - -func parseToolCall(call map[string]any) (string, string) { - if call == nil { - return "", "" - } - name := "" - args := "" - if fn, ok := call["function"].(map[string]any); ok { - name = readString(fn["name"]) - args = readString(fn["arguments"]) - } - if name == "" { - name = readString(call["name"]) - } - if args == "" { - args = readString(call["arguments"]) - } - return name, args -} - func parseToolArguments(raw string, out any) bool { if strings.TrimSpace(raw) == "" { return false @@ -578,12 +518,6 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool { return false } -func readString(value any) string { - if raw, ok := value.(string); ok { - return raw - } - return "" -} func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { if state, ok := IdentityStateFromContext(ctx); ok { diff --git a/internal/router/channel_test.go b/internal/router/channel_test.go index b65daf9a..28c35631 100644 --- a/internal/router/channel_test.go +++ b/internal/router/channel_test.go @@ -129,8 +129,8 @@ func TestChannelInboundProcessorBoundUser(t *testing.T) { } gateway := &fakeChatGateway{ resp: chat.ChatResponse{ - Messages: []chat.GatewayMessage{ - {"role": "assistant", "content": "AI回复内容"}, + Messages: []chat.ModelMessage{ + {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, }, }, } @@ -218,8 +218,8 @@ func TestChannelInboundProcessorSilentReply(t *testing.T) { } gateway := &fakeChatGateway{ resp: chat.ChatResponse{ - Messages: []chat.GatewayMessage{ - {"role": "assistant", "content": "NO_REPLY"}, + Messages: []chat.ModelMessage{ + {Role: "assistant", Content: chat.NewTextContent("NO_REPLY")}, }, }, } @@ -255,20 +255,20 @@ func TestChannelInboundProcessorSuppressOnToolSend(t *testing.T) { } gateway := &fakeChatGateway{ resp: chat.ChatResponse{ - Messages: []chat.GatewayMessage{ + Messages: []chat.ModelMessage{ { - "role": "assistant", - "tool_calls": []any{ - map[string]any{ - "type": "function", - "function": map[string]any{ - "name": "send_message", - "arguments": `{"platform":"feishu","target":"target-id","message":{"text":"AI回复内容"}}`, + Role: "assistant", + ToolCalls: []chat.ToolCall{ + { + Type: "function", + Function: chat.ToolCallFunction{ + Name: "send_message", + Arguments: `{"platform":"feishu","target":"target-id","message":{"text":"AI回复内容"}}`, }, }, }, }, - {"role": "assistant", "content": "AI回复内容"}, + {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, }, }, } @@ -304,20 +304,20 @@ func TestChannelInboundProcessorDedupeWithToolSend(t *testing.T) { } gateway := &fakeChatGateway{ resp: chat.ChatResponse{ - Messages: []chat.GatewayMessage{ + Messages: []chat.ModelMessage{ { - "role": "assistant", - "tool_calls": []any{ - map[string]any{ - "type": "function", - "function": map[string]any{ - "name": "send_message", - "arguments": `{"platform":"feishu","target":"other-target","message":{"text":"AI回复内容"}}`, + Role: "assistant", + ToolCalls: []chat.ToolCall{ + { + Type: "function", + Function: chat.ToolCallFunction{ + Name: "send_message", + Arguments: `{"platform":"feishu","target":"other-target","message":{"text":"AI回复内容"}}`, }, }, }, }, - {"role": "assistant", "content": "AI回复内容"}, + {Role: "assistant", Content: chat.NewTextContent("AI回复内容")}, }, }, } diff --git a/mise.toml b/mise.toml index 3bec6710..abb5ec83 100644 --- a/mise.toml +++ b/mise.toml @@ -10,7 +10,7 @@ bun = "latest" # pnpm for workspace management pnpm = "10" # Lima for macOS -lima = { version = "latest", platform = "darwin" } +lima = { version = "system", platform = "darwin" } [task_config] dir = "{{cwd}}"