mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
refactor(core): codebase quality cleanup
- Remove user-level model settings (chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language) from users table - Merge migration 0002 into 0001, remove compatibility migrations - Delete dead conversation/resolver.go (1177 lines, only flow/resolver.go used) - Remove type aliases (Chat=Conversation, types_alias.go) - Fix SQL: remove AND false stub, fix UpdateChatTitle model_id, reset model IDs in DeleteSettings, add preauth expiry filter, add ListMessages limit, remove 10 dead queries - Extract shared handler helpers (RequireChannelIdentityID, AuthorizeBotAccess) - Rename internal/router to internal/channel/inbound - Fix identity confusion: remove UserID->ChannelIdentityID fallbacks - Fix all _ = var patterns with proper error logging - Fix error propagation: storeMessages, rescheduleJob, botContainerID - Fix naming: ModelId->ModelID, active->is_active, Duration semantic fix - Remove dead code: mcpService, ReplyTarget, callMCPServer, sshShellQuote, buildSessionMetadata, ChatRequest.Language, TriggerPayload.ChatID - Fix code quality: errors.Is(), remove goto, CreateHuman deprecated - Remove Enable model endpoint and user-level settings CLI commands - Regenerate sqlc, swagger, SDK
This commit is contained in:
+6
-6
@@ -50,7 +50,7 @@ import (
|
||||
"github.com/memohai/memoh/internal/policy"
|
||||
"github.com/memohai/memoh/internal/preauth"
|
||||
"github.com/memohai/memoh/internal/providers"
|
||||
"github.com/memohai/memoh/internal/router"
|
||||
"github.com/memohai/memoh/internal/channel/inbound"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
"github.com/memohai/memoh/internal/server"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
@@ -305,8 +305,8 @@ func provideScheduleTriggerer(resolver *flow.Resolver) schedule.Triggerer {
|
||||
// conversation flow
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, mcpConnService *mcp.ConnectionService, containerdHandler *handlers.ContainerdHandler) *flow.Resolver {
|
||||
resolver := flow.NewResolver(log, modelsService, queries, memoryService, chatService, msgService, settingsService, mcpConnService, cfg.AgentGateway.BaseURL(), 120*time.Second)
|
||||
func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, containerdHandler *handlers.ContainerdHandler) *flow.Resolver {
|
||||
resolver := flow.NewResolver(log, modelsService, queries, memoryService, chatService, msgService, settingsService, cfg.AgentGateway.BaseURL(), 120*time.Second)
|
||||
resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler})
|
||||
return resolver
|
||||
}
|
||||
@@ -324,11 +324,11 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub) *channel.Regi
|
||||
return registry
|
||||
}
|
||||
|
||||
func provideChannelRouter(log *slog.Logger, registry *channel.Registry, routeService *route.DBService, msgService *message.DBService, resolver *flow.Resolver, identityService *identities.Service, botService *bots.Service, policyService *policy.Service, preauthService *preauth.Service, bindService *bind.Service, rc *boot.RuntimeConfig) *router.ChannelInboundProcessor {
|
||||
return router.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, botService, policyService, preauthService, bindService, rc.JwtSecret, 5*time.Minute)
|
||||
func provideChannelRouter(log *slog.Logger, registry *channel.Registry, routeService *route.DBService, msgService *message.DBService, resolver *flow.Resolver, identityService *identities.Service, botService *bots.Service, policyService *policy.Service, preauthService *preauth.Service, bindService *bind.Service, rc *boot.RuntimeConfig) *inbound.ChannelInboundProcessor {
|
||||
return inbound.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, botService, policyService, preauthService, bindService, rc.JwtSecret, 5*time.Minute)
|
||||
}
|
||||
|
||||
func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, channelRouter *router.ChannelInboundProcessor) *channel.Manager {
|
||||
func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, channelRouter *inbound.ChannelInboundProcessor) *channel.Manager {
|
||||
mgr := channel.NewManager(log, registry, channelService, channelRouter)
|
||||
if mw := channelRouter.IdentityMiddleware(); mw != nil {
|
||||
mgr.Use(mw)
|
||||
|
||||
@@ -19,11 +19,6 @@ CREATE TABLE IF NOT EXISTS users (
|
||||
avatar_url TEXT,
|
||||
data_root TEXT,
|
||||
last_login_at TIMESTAMPTZ,
|
||||
chat_model_id TEXT,
|
||||
memory_model_id TEXT,
|
||||
embedding_model_id TEXT,
|
||||
max_context_load_time INTEGER NOT NULL DEFAULT 1440,
|
||||
language TEXT NOT NULL DEFAULT 'auto',
|
||||
is_active BOOLEAN NOT NULL DEFAULT true,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
@@ -39,6 +34,7 @@ CREATE TABLE IF NOT EXISTS channel_identities (
|
||||
channel_type TEXT NOT NULL,
|
||||
channel_subject_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
avatar_url TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
-- 0002_channel_identity_avatar (down)
|
||||
ALTER TABLE channel_identities DROP COLUMN IF EXISTS avatar_url;
|
||||
@@ -1,3 +0,0 @@
|
||||
-- 0002_channel_identity_avatar
|
||||
-- Add avatar_url column to channel_identities for sender profile display.
|
||||
ALTER TABLE channel_identities ADD COLUMN IF NOT EXISTS avatar_url TEXT;
|
||||
@@ -43,8 +43,3 @@ SET user_id = $2, updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id, user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata, created_at, updated_at;
|
||||
|
||||
-- name: ClearChannelIdentityLinkedUser :one
|
||||
UPDATE channel_identities
|
||||
SET user_id = NULL, updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id, user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata, created_at, updated_at;
|
||||
|
||||
@@ -29,9 +29,6 @@ ON CONFLICT (container_id) DO UPDATE SET
|
||||
last_stopped_at = EXCLUDED.last_stopped_at,
|
||||
updated_at = now();
|
||||
|
||||
-- name: GetContainerByContainerID :one
|
||||
SELECT * FROM containers WHERE container_id = sqlc.arg(container_id);
|
||||
|
||||
-- name: GetContainerByBotID :one
|
||||
SELECT * FROM containers WHERE bot_id = sqlc.arg(bot_id) ORDER BY updated_at DESC LIMIT 1;
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ SELECT
|
||||
b.display_name AS title,
|
||||
b.owner_user_id AS created_by_user_id,
|
||||
b.metadata AS metadata,
|
||||
chat_models.model_id AS model_id,
|
||||
b.created_at,
|
||||
b.updated_at,
|
||||
'participant'::text AS access_mode,
|
||||
@@ -69,6 +70,7 @@ SELECT
|
||||
NULL::timestamptz AS last_observed_at
|
||||
FROM bots b
|
||||
LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id)
|
||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||
WHERE b.id = sqlc.arg(bot_id)
|
||||
AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL)
|
||||
ORDER BY b.updated_at DESC;
|
||||
@@ -102,25 +104,29 @@ SELECT
|
||||
FROM bots b
|
||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||
WHERE b.id = $1
|
||||
AND false
|
||||
ORDER BY b.created_at DESC;
|
||||
|
||||
-- name: UpdateChatTitle :one
|
||||
UPDATE bots
|
||||
SET display_name = sqlc.arg(title),
|
||||
updated_at = now()
|
||||
WHERE id = sqlc.arg(id)
|
||||
RETURNING
|
||||
id,
|
||||
id AS bot_id,
|
||||
CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind,
|
||||
WITH updated AS (
|
||||
UPDATE bots
|
||||
SET display_name = sqlc.arg(title),
|
||||
updated_at = now()
|
||||
WHERE bots.id = sqlc.arg(bot_id)
|
||||
RETURNING *
|
||||
)
|
||||
SELECT
|
||||
updated.id AS id,
|
||||
updated.id AS bot_id,
|
||||
CASE WHEN updated.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
|
||||
NULL::uuid AS parent_chat_id,
|
||||
display_name AS title,
|
||||
owner_user_id AS created_by_user_id,
|
||||
metadata,
|
||||
NULL::text AS model_id,
|
||||
created_at,
|
||||
updated_at;
|
||||
updated.display_name AS title,
|
||||
updated.owner_user_id AS created_by_user_id,
|
||||
updated.metadata,
|
||||
chat_models.model_id AS model_id,
|
||||
updated.created_at,
|
||||
updated.updated_at
|
||||
FROM updated
|
||||
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
|
||||
|
||||
-- name: TouchChat :exec
|
||||
UPDATE bots
|
||||
|
||||
@@ -6,6 +6,3 @@ VALUES (
|
||||
sqlc.arg(event_type),
|
||||
sqlc.arg(payload)
|
||||
);
|
||||
|
||||
-- name: ListLifecycleEventsByContainerID :many
|
||||
SELECT * FROM lifecycle_events WHERE container_id = sqlc.arg(container_id) ORDER BY created_at ASC;
|
||||
|
||||
@@ -56,7 +56,8 @@ SELECT
|
||||
FROM bot_history_messages m
|
||||
LEFT JOIN channel_identities ci ON ci.id = m.sender_channel_identity_id
|
||||
WHERE m.bot_id = sqlc.arg(bot_id)
|
||||
ORDER BY m.created_at ASC;
|
||||
ORDER BY m.created_at ASC
|
||||
LIMIT 10000;
|
||||
|
||||
-- name: ListMessagesSince :many
|
||||
SELECT
|
||||
|
||||
@@ -136,29 +136,9 @@ VALUES (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetModelVariantByID :one
|
||||
SELECT * FROM model_variants WHERE id = sqlc.arg(id);
|
||||
|
||||
-- name: ListModelVariantsByModelUUID :many
|
||||
SELECT * FROM model_variants
|
||||
WHERE model_uuid = sqlc.arg(model_uuid)
|
||||
ORDER BY weight DESC, created_at DESC;
|
||||
|
||||
-- name: ListModelVariantsByVariantID :many
|
||||
SELECT * FROM model_variants
|
||||
WHERE variant_id = sqlc.arg(variant_id)
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: UpdateModelVariant :one
|
||||
UPDATE model_variants
|
||||
SET
|
||||
variant_id = sqlc.arg(variant_id),
|
||||
weight = sqlc.arg(weight),
|
||||
metadata = sqlc.arg(metadata),
|
||||
updated_at = now()
|
||||
WHERE id = sqlc.arg(id)
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteModelVariant :exec
|
||||
DELETE FROM model_variants WHERE id = sqlc.arg(id);
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ RETURNING id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at;
|
||||
SELECT id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at
|
||||
FROM bot_preauth_keys
|
||||
WHERE token = $1
|
||||
AND used_at IS NULL
|
||||
AND (expires_at IS NULL OR expires_at > now())
|
||||
LIMIT 1;
|
||||
|
||||
-- name: MarkBotPreauthKeyUsed :one
|
||||
|
||||
+3
-16
@@ -1,19 +1,3 @@
|
||||
-- name: GetSettingsByUserID :one
|
||||
SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
|
||||
FROM users
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: UpsertUserSettings :one
|
||||
UPDATE users
|
||||
SET chat_model_id = $2,
|
||||
memory_model_id = $3,
|
||||
embedding_model_id = $4,
|
||||
max_context_load_time = $5,
|
||||
language = $6,
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language;
|
||||
|
||||
-- name: GetSettingsByBotID :one
|
||||
SELECT
|
||||
bots.id AS bot_id,
|
||||
@@ -60,5 +44,8 @@ UPDATE bots
|
||||
SET max_context_load_time = 1440,
|
||||
language = 'auto',
|
||||
allow_guest = false,
|
||||
chat_model_id = NULL,
|
||||
memory_model_id = NULL,
|
||||
embedding_model_id = NULL,
|
||||
updated_at = now()
|
||||
WHERE id = $1;
|
||||
|
||||
@@ -8,6 +8,3 @@ VALUES (
|
||||
sqlc.arg(digest)
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
-- name: ListSnapshotsByContainerID :many
|
||||
SELECT * FROM snapshots WHERE container_id = sqlc.arg(container_id) ORDER BY created_at ASC;
|
||||
|
||||
@@ -8,13 +8,6 @@ SELECT *
|
||||
FROM users
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: UpdateUserStatus :one
|
||||
UPDATE users
|
||||
SET is_active = $2,
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING *;
|
||||
|
||||
-- name: CreateAccount :one
|
||||
UPDATE users
|
||||
SET username = sqlc.arg(username),
|
||||
@@ -54,9 +47,6 @@ ON CONFLICT (username) DO UPDATE SET
|
||||
updated_at = now()
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetAccountByUsername :one
|
||||
SELECT * FROM users WHERE username = sqlc.arg(username);
|
||||
|
||||
-- name: GetAccountByIdentity :one
|
||||
SELECT * FROM users WHERE username = sqlc.arg(identity) OR email = sqlc.arg(identity);
|
||||
|
||||
|
||||
@@ -189,6 +189,8 @@ func (s *Service) Create(ctx context.Context, userID string, req CreateAccountRe
|
||||
}
|
||||
|
||||
// CreateHuman keeps compatibility with older call sites.
|
||||
//
|
||||
// Deprecated: use Create directly.
|
||||
func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
@@ -223,7 +225,7 @@ func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAcco
|
||||
if err != nil {
|
||||
return Account{}, err
|
||||
}
|
||||
role := fmt.Sprint(existing.Role)
|
||||
role := existing.Role
|
||||
if req.Role != nil {
|
||||
role, err = normalizeRole(*req.Role)
|
||||
if err != nil {
|
||||
@@ -412,7 +414,7 @@ func toAccount(row sqlc.User) Account {
|
||||
ID: row.ID.String(),
|
||||
Username: username,
|
||||
Email: email,
|
||||
Role: fmt.Sprint(row.Role),
|
||||
Role: row.Role,
|
||||
DisplayName: displayName,
|
||||
AvatarURL: avatarURL,
|
||||
IsActive: row.IsActive,
|
||||
|
||||
@@ -419,7 +419,9 @@ func (s *Service) enqueueDeleteLifecycle(botID string) {
|
||||
slog.String("bot_id", botID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
_ = s.updateStatus(ctx, botID, BotStatusReady)
|
||||
if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil {
|
||||
s.logger.Error("revert bot status failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := s.queries.DeleteBotByID(ctx, botUUID); err != nil {
|
||||
@@ -427,7 +429,9 @@ func (s *Service) enqueueDeleteLifecycle(botID string) {
|
||||
slog.String("bot_id", botID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
_ = s.updateStatus(ctx, botID, BotStatusReady)
|
||||
if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil {
|
||||
s.logger.Error("revert bot status failed", slog.String("bot_id", botID), slog.Any("error", err))
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -3,6 +3,7 @@ package feishu
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -26,7 +27,9 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1, botOpenID string) ch
|
||||
|
||||
var contentMap map[string]any
|
||||
if message.Content != nil {
|
||||
_ = json.Unmarshal([]byte(*message.Content), &contentMap)
|
||||
if err := json.Unmarshal([]byte(*message.Content), &contentMap); err != nil {
|
||||
slog.Warn("feishu inbound: unmarshal content failed", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
isMentioned := isFeishuBotMentioned(contentMap, message.Mentions, botOpenID)
|
||||
|
||||
|
||||
@@ -134,8 +134,12 @@ func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamE
|
||||
finalText := strings.TrimSpace(s.buf.String())
|
||||
s.mu.Unlock()
|
||||
if finalText != "" {
|
||||
_ = s.ensureStreamMessage(ctx, finalText)
|
||||
_ = s.editStreamMessage(ctx, finalText)
|
||||
if err := s.ensureStreamMessage(ctx, finalText); err != nil {
|
||||
slog.Warn("telegram: ensure stream message failed", slog.Any("error", err))
|
||||
}
|
||||
if err := s.editStreamMessage(ctx, finalText); err != nil {
|
||||
slog.Warn("telegram: edit stream message failed", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -264,7 +264,9 @@ func (s *Service) LinkChannelIdentityToUser(ctx context.Context, channelIdentity
|
||||
func toChannelIdentity(row sqlc.ChannelIdentity) ChannelIdentity {
|
||||
var metadata map[string]any
|
||||
if len(row.Metadata) > 0 {
|
||||
_ = json.Unmarshal(row.Metadata, &metadata)
|
||||
if err := json.Unmarshal(row.Metadata, &metadata); err != nil {
|
||||
slog.Warn("unmarshal channel identity metadata failed", slog.String("id", row.ID.String()), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
if metadata == nil {
|
||||
metadata = map[string]any{}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package router
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -201,7 +201,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
|
||||
var desc channel.Descriptor
|
||||
if p.registry != nil {
|
||||
desc, _ = p.registry.GetDescriptor(msg.Channel)
|
||||
desc, _ = p.registry.GetDescriptor(msg.Channel) //nolint:errcheck // descriptor lookup is best-effort
|
||||
}
|
||||
statusInfo := channel.ProcessingStatusInfo{
|
||||
BotID: identity.BotID,
|
||||
@@ -269,7 +269,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
|
||||
return err
|
||||
}
|
||||
|
||||
chunkCh, streamErrCh := p.runner.StreamChat(ctx, flow.ChatRequest{
|
||||
chunkCh, streamErrCh := p.runner.StreamChat(ctx, conversation.ChatRequest{
|
||||
BotID: identity.BotID,
|
||||
ChatID: activeChatID,
|
||||
Token: token,
|
||||
@@ -1,4 +1,4 @@
|
||||
package router
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package router
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -568,23 +568,3 @@ func (r *IdentityResolver) tryLinkConfiglessChannelIdentityToUser(ctx context.Co
|
||||
}
|
||||
return candidateUserID
|
||||
}
|
||||
|
||||
func buildSessionMetadata(msg channel.InboundMessage) map[string]any {
|
||||
metadata := map[string]any{}
|
||||
if strings.TrimSpace(msg.Source) != "" {
|
||||
metadata["source"] = strings.TrimSpace(msg.Source)
|
||||
}
|
||||
if strings.TrimSpace(msg.Message.ID) != "" {
|
||||
metadata["message_id"] = strings.TrimSpace(msg.Message.ID)
|
||||
}
|
||||
if strings.TrimSpace(msg.Conversation.Type) != "" {
|
||||
metadata["conversation_type"] = strings.TrimSpace(msg.Conversation.Type)
|
||||
}
|
||||
if strings.TrimSpace(msg.Conversation.Name) != "" {
|
||||
metadata["conversation_name"] = strings.TrimSpace(msg.Conversation.Name)
|
||||
}
|
||||
if !msg.ReceivedAt.IsZero() {
|
||||
metadata["received_at"] = msg.ReceivedAt.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package router
|
||||
package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
// ConversationService contains the minimal conversation behavior required by route resolution.
|
||||
type ConversationService interface {
|
||||
Create(ctx context.Context, botID, channelIdentityID string, req conversation.CreateRequest) (conversation.Chat, error)
|
||||
Create(ctx context.Context, botID, channelIdentityID string, req conversation.CreateRequest) (conversation.Conversation, error)
|
||||
IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error)
|
||||
AddParticipant(ctx context.Context, conversationID, channelIdentityID, role string) (conversation.Participant, error)
|
||||
}
|
||||
@@ -250,11 +250,12 @@ func (s *DBService) resolveConversationCreatorChannelIdentityID(ctx context.Cont
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
ownerChannelIdentityID := row.OwnerUserID.String()
|
||||
if strings.TrimSpace(ownerChannelIdentityID) == "" {
|
||||
// NOTE: OwnerUserID is the bot owner's user ID. Used as fallback creator for group conversations.
|
||||
ownerUserID := row.OwnerUserID.String()
|
||||
if strings.TrimSpace(ownerUserID) == "" {
|
||||
return fallback
|
||||
}
|
||||
return ownerChannelIdentityID
|
||||
return ownerUserID
|
||||
}
|
||||
|
||||
func toRouteFromCreate(row sqlc.CreateChatRouteRow) Route {
|
||||
@@ -357,6 +358,8 @@ func parseJSONMap(data []byte) map[string]any {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(data, &m)
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err))
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package flow
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
// ExtractAssistantOutputs collects assistant-role outputs from a slice of ModelMessages.
|
||||
func ExtractAssistantOutputs(messages []ModelMessage) []AssistantOutput {
|
||||
func ExtractAssistantOutputs(messages []conversation.ModelMessage) []conversation.AssistantOutput {
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
outputs := make([]AssistantOutput, 0, len(messages))
|
||||
outputs := make([]conversation.AssistantOutput, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
if msg.Role != "assistant" {
|
||||
continue
|
||||
@@ -17,16 +21,16 @@ func ExtractAssistantOutputs(messages []ModelMessage) []AssistantOutput {
|
||||
if content == "" && len(parts) == 0 {
|
||||
continue
|
||||
}
|
||||
outputs = append(outputs, AssistantOutput{Content: content, Parts: parts})
|
||||
outputs = append(outputs, conversation.AssistantOutput{Content: content, Parts: parts})
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
|
||||
func filterContentParts(parts []ContentPart) []ContentPart {
|
||||
func filterContentParts(parts []conversation.ContentPart) []conversation.ContentPart {
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
filtered := make([]ContentPart, 0, len(parts))
|
||||
filtered := make([]conversation.ContentPart, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if p.HasValue() {
|
||||
filtered = append(filtered, p)
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
"github.com/memohai/memoh/internal/memory"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
@@ -60,7 +59,6 @@ type Resolver struct {
|
||||
conversationSvc ConversationSettingsReader
|
||||
messageService messagepkg.Service
|
||||
settingsService *settings.Service
|
||||
mcpService *mcp.ConnectionService
|
||||
skillLoader SkillLoader
|
||||
gatewayBaseURL string
|
||||
timeout time.Duration
|
||||
@@ -78,7 +76,6 @@ func NewResolver(
|
||||
conversationSvc ConversationSettingsReader,
|
||||
messageService messagepkg.Service,
|
||||
settingsService *settings.Service,
|
||||
mcpService *mcp.ConnectionService,
|
||||
gatewayBaseURL string,
|
||||
timeout time.Duration,
|
||||
) *Resolver {
|
||||
@@ -96,7 +93,6 @@ func NewResolver(
|
||||
conversationSvc: conversationSvc,
|
||||
messageService: messageService,
|
||||
settingsService: settingsService,
|
||||
mcpService: mcpService,
|
||||
gatewayBaseURL: gatewayBaseURL,
|
||||
timeout: timeout,
|
||||
logger: log.With(slog.String("service", "conversation_resolver")),
|
||||
@@ -126,7 +122,6 @@ type gatewayIdentity struct {
|
||||
ChannelIdentityID string `json:"channelIdentityId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
CurrentPlatform string `json:"currentPlatform,omitempty"`
|
||||
ReplyTarget string `json:"replyTarget,omitempty"`
|
||||
SessionToken string `json:"sessionToken,omitempty"`
|
||||
}
|
||||
|
||||
@@ -138,22 +133,22 @@ type gatewaySkill struct {
|
||||
}
|
||||
|
||||
type gatewayRequest struct {
|
||||
Model gatewayModelConfig `json:"model"`
|
||||
ActiveContextTime int `json:"activeContextTime"`
|
||||
Channels []string `json:"channels"`
|
||||
CurrentChannel string `json:"currentChannel"`
|
||||
AllowedActions []string `json:"allowedActions,omitempty"`
|
||||
Messages []ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
UsableSkills []gatewaySkill `json:"usableSkills"`
|
||||
Query string `json:"query"`
|
||||
Identity gatewayIdentity `json:"identity"`
|
||||
Attachments []any `json:"attachments"`
|
||||
Model gatewayModelConfig `json:"model"`
|
||||
ActiveContextTime int `json:"activeContextTime"`
|
||||
Channels []string `json:"channels"`
|
||||
CurrentChannel string `json:"currentChannel"`
|
||||
AllowedActions []string `json:"allowedActions,omitempty"`
|
||||
Messages []conversation.ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
UsableSkills []gatewaySkill `json:"usableSkills"`
|
||||
Query string `json:"query,omitempty"`
|
||||
Identity gatewayIdentity `json:"identity"`
|
||||
Attachments []any `json:"attachments"`
|
||||
}
|
||||
|
||||
type gatewayResponse struct {
|
||||
Messages []ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Messages []conversation.ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
}
|
||||
|
||||
// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule.
|
||||
@@ -168,17 +163,8 @@ type gatewaySchedule struct {
|
||||
|
||||
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
|
||||
type triggerScheduleRequest struct {
|
||||
Model gatewayModelConfig `json:"model"`
|
||||
ActiveContextTime int `json:"activeContextTime"`
|
||||
Channels []string `json:"channels"`
|
||||
CurrentChannel string `json:"currentChannel"`
|
||||
AllowedActions []string `json:"allowedActions,omitempty"`
|
||||
Messages []ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
UsableSkills []gatewaySkill `json:"usableSkills"`
|
||||
Identity gatewayIdentity `json:"identity"`
|
||||
Attachments []any `json:"attachments"`
|
||||
Schedule gatewaySchedule `json:"schedule"`
|
||||
gatewayRequest
|
||||
Schedule gatewaySchedule `json:"schedule"`
|
||||
}
|
||||
|
||||
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
|
||||
@@ -189,7 +175,7 @@ type resolvedContext struct {
|
||||
provider sqlc.LlmProvider
|
||||
}
|
||||
|
||||
func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContext, error) {
|
||||
func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) {
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
return resolvedContext{}, fmt.Errorf("query is required")
|
||||
}
|
||||
@@ -208,7 +194,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
}
|
||||
|
||||
// Check chat-level model override.
|
||||
var chatSettings Settings
|
||||
var chatSettings conversation.Settings
|
||||
if r.conversationSvc != nil {
|
||||
chatSettings, err = r.conversationSvc.GetSettings(ctx, req.ChatID)
|
||||
if err != nil {
|
||||
@@ -216,11 +202,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
}
|
||||
}
|
||||
|
||||
userSettings, err := r.loadUserSettings(ctx, req.UserID)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings)
|
||||
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, chatSettings)
|
||||
if err != nil {
|
||||
return resolvedContext{}, err
|
||||
}
|
||||
@@ -230,7 +212,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
}
|
||||
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
|
||||
|
||||
var messages []ModelMessage
|
||||
var messages []conversation.ModelMessage
|
||||
if !skipHistory && r.conversationSvc != nil {
|
||||
messages, err = r.loadMessages(ctx, req.ChatID, maxCtx)
|
||||
if err != nil {
|
||||
@@ -285,10 +267,9 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
Identity: gatewayIdentity{
|
||||
BotID: req.BotID,
|
||||
ContainerID: containerID,
|
||||
ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID),
|
||||
ChannelIdentityID: strings.TrimSpace(req.SourceChannelIdentityID),
|
||||
DisplayName: r.resolveDisplayName(ctx, req),
|
||||
CurrentPlatform: req.CurrentChannel,
|
||||
ReplyTarget: "",
|
||||
SessionToken: req.ChatToken,
|
||||
},
|
||||
Attachments: []any{},
|
||||
@@ -300,19 +281,19 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
|
||||
// --- Chat ---
|
||||
|
||||
// Chat sends a synchronous chat request to the agent gateway and stores the result.
|
||||
func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) {
|
||||
func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) {
|
||||
rc, err := r.resolve(ctx, req)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
return conversation.ChatResponse{}, err
|
||||
}
|
||||
resp, err := r.postChat(ctx, rc.payload, req.Token)
|
||||
if err != nil {
|
||||
return ChatResponse{}, err
|
||||
return conversation.ChatResponse{}, err
|
||||
}
|
||||
if err := r.storeRound(ctx, req, resp.Messages); err != nil {
|
||||
return ChatResponse{}, err
|
||||
return conversation.ChatResponse{}, err
|
||||
}
|
||||
return ChatResponse{
|
||||
return conversation.ChatResponse{
|
||||
Messages: resp.Messages,
|
||||
Skills: resp.Skills,
|
||||
Model: rc.model.ModelID,
|
||||
@@ -331,11 +312,8 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
return fmt.Errorf("schedule command is required")
|
||||
}
|
||||
|
||||
chatID := payload.ChatID
|
||||
if strings.TrimSpace(chatID) == "" {
|
||||
chatID = "schedule-" + payload.ID
|
||||
}
|
||||
req := ChatRequest{
|
||||
chatID := "schedule-" + payload.ID
|
||||
req := conversation.ChatRequest{
|
||||
BotID: botID,
|
||||
ChatID: chatID,
|
||||
Query: payload.Command,
|
||||
@@ -347,22 +325,12 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
return err
|
||||
}
|
||||
|
||||
schedulePayload := rc.payload
|
||||
schedulePayload.Identity.ChannelIdentityID = strings.TrimSpace(payload.OwnerUserID)
|
||||
schedulePayload.Identity.DisplayName = "Scheduler"
|
||||
|
||||
triggerReq := triggerScheduleRequest{
|
||||
Model: rc.payload.Model,
|
||||
ActiveContextTime: rc.payload.ActiveContextTime,
|
||||
Channels: rc.payload.Channels,
|
||||
CurrentChannel: rc.payload.CurrentChannel,
|
||||
AllowedActions: rc.payload.AllowedActions,
|
||||
Messages: rc.payload.Messages,
|
||||
Skills: rc.payload.Skills,
|
||||
UsableSkills: rc.payload.UsableSkills,
|
||||
Identity: gatewayIdentity{
|
||||
BotID: rc.payload.Identity.BotID,
|
||||
ContainerID: rc.payload.Identity.ContainerID,
|
||||
ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID),
|
||||
DisplayName: "Scheduler",
|
||||
},
|
||||
Attachments: rc.payload.Attachments,
|
||||
gatewayRequest: schedulePayload,
|
||||
Schedule: gatewaySchedule{
|
||||
ID: payload.ID,
|
||||
Name: payload.Name,
|
||||
@@ -383,8 +351,8 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
|
||||
// --- StreamChat ---
|
||||
|
||||
// StreamChat sends a streaming chat request to the agent gateway.
|
||||
func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) {
|
||||
chunkCh := make(chan StreamChunk)
|
||||
func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) {
|
||||
chunkCh := make(chan conversation.StreamChunk)
|
||||
errCh := make(chan error, 1)
|
||||
r.logger.Info("gateway stream start",
|
||||
slog.String("bot_id", req.BotID),
|
||||
@@ -513,7 +481,7 @@ func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerSched
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error {
|
||||
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req conversation.ChatRequest, chunkCh chan<- conversation.StreamChunk) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -564,7 +532,7 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req C
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
chunkCh <- StreamChunk([]byte(data))
|
||||
chunkCh <- conversation.StreamChunk([]byte(data))
|
||||
|
||||
if stored {
|
||||
continue
|
||||
@@ -579,7 +547,7 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req C
|
||||
}
|
||||
|
||||
// tryStoreStream attempts to extract final messages from a stream event and persist them.
|
||||
func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) {
|
||||
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, eventType, data string) (bool, error) {
|
||||
// event: done + data: {messages: [...]}
|
||||
if eventType == "done" {
|
||||
var resp gatewayResponse
|
||||
@@ -590,10 +558,10 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventTyp
|
||||
|
||||
// data: {"type":"text_delta"|"agent_end"|"done", ...}
|
||||
var envelope struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Messages []ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Messages []conversation.ModelMessage `json:"messages"`
|
||||
Skills []string `json:"skills"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &envelope); err == nil {
|
||||
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
|
||||
@@ -630,12 +598,13 @@ func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit strin
|
||||
}
|
||||
}
|
||||
}
|
||||
r.logger.Warn("no container found for bot, using fallback", slog.String("bot_id", botID))
|
||||
return "mcp-" + botID
|
||||
}
|
||||
|
||||
// --- message loading ---
|
||||
|
||||
func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) {
|
||||
func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]conversation.ModelMessage, error) {
|
||||
if r.messageService == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -644,12 +613,13 @@ func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []ModelMessage
|
||||
var result []conversation.ModelMessage
|
||||
for _, m := range msgs {
|
||||
var mm ModelMessage
|
||||
var mm conversation.ModelMessage
|
||||
if err := json.Unmarshal(m.Content, &mm); err != nil {
|
||||
// Fallback: treat content as text string.
|
||||
mm = ModelMessage{Role: m.Role, Content: m.Content}
|
||||
r.logger.Warn("loadMessages: content unmarshal failed, treating as raw text",
|
||||
slog.String("chat_id", chatID), slog.Any("error", err))
|
||||
mm = conversation.ModelMessage{Role: m.Role, Content: m.Content}
|
||||
} else {
|
||||
mm.Role = m.Role
|
||||
}
|
||||
@@ -663,7 +633,7 @@ type memoryContextItem struct {
|
||||
Item memory.MemoryItem
|
||||
}
|
||||
|
||||
func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage {
|
||||
func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversation.ChatRequest) *conversation.ModelMessage {
|
||||
if r.memoryService == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -680,7 +650,7 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest
|
||||
Filters: map[string]any{
|
||||
"namespace": sharedMemoryNamespace,
|
||||
"scopeId": req.BotID,
|
||||
"botId": req.BotID,
|
||||
"bot_id": req.BotID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -732,16 +702,16 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest
|
||||
if payload == "" {
|
||||
return nil
|
||||
}
|
||||
msg := ModelMessage{
|
||||
msg := conversation.ModelMessage{
|
||||
Role: "system",
|
||||
Content: NewTextContent(payload),
|
||||
Content: conversation.NewTextContent(payload),
|
||||
}
|
||||
return &msg
|
||||
}
|
||||
|
||||
// --- store helpers ---
|
||||
|
||||
func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error {
|
||||
func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.ChatRequest) error {
|
||||
if r.messageService == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -753,9 +723,9 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
message := ModelMessage{
|
||||
message := conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: NewTextContent(text),
|
||||
Content: conversation.NewTextContent(text),
|
||||
}
|
||||
content, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
@@ -776,10 +746,10 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) erro
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error {
|
||||
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage) error {
|
||||
// Add user query as the first message if not already present in the round.
|
||||
// This ensures the user's prompt is persisted alongside the assistant's response.
|
||||
fullRound := make([]ModelMessage, 0, len(messages)+1)
|
||||
fullRound := make([]conversation.ModelMessage, 0, len(messages)+1)
|
||||
hasUserQuery := false
|
||||
for _, m := range messages {
|
||||
if m.Role == "user" && m.TextContent() == req.Query {
|
||||
@@ -788,9 +758,9 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M
|
||||
}
|
||||
}
|
||||
if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" {
|
||||
fullRound = append(fullRound, ModelMessage{
|
||||
fullRound = append(fullRound, conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: NewTextContent(req.Query),
|
||||
Content: conversation.NewTextContent(req.Query),
|
||||
})
|
||||
}
|
||||
for _, m := range messages {
|
||||
@@ -809,7 +779,7 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) {
|
||||
func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage) {
|
||||
if r.messageService == nil {
|
||||
return
|
||||
}
|
||||
@@ -821,6 +791,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
||||
for _, msg := range messages {
|
||||
content, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
r.logger.Warn("storeMessages: marshal failed", slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
messageSenderChannelIdentityID := ""
|
||||
@@ -852,7 +823,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
|
||||
}
|
||||
}
|
||||
|
||||
func buildRouteMetadata(req ChatRequest) map[string]any {
|
||||
func buildRouteMetadata(req conversation.ChatRequest) map[string]any {
|
||||
if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -866,25 +837,17 @@ func buildRouteMetadata(req ChatRequest) map[string]any {
|
||||
return meta
|
||||
}
|
||||
|
||||
func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) {
|
||||
func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req conversation.ChatRequest) (string, string) {
|
||||
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
|
||||
userID := strings.TrimSpace(req.UserID)
|
||||
|
||||
channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID)
|
||||
userAsUserValid := r.isExistingUserID(ctx, userID)
|
||||
userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID)
|
||||
|
||||
senderChannelIdentityID := ""
|
||||
switch {
|
||||
case channelIdentityValid:
|
||||
if r.isExistingChannelIdentityID(ctx, channelIdentityID) {
|
||||
senderChannelIdentityID = channelIdentityID
|
||||
case userAsChannelIdentityValid && !userAsUserValid:
|
||||
// Some flows may carry channel_identity_id in req.UserID.
|
||||
senderChannelIdentityID = userID
|
||||
}
|
||||
|
||||
senderUserID := ""
|
||||
if userAsUserValid {
|
||||
if r.isExistingUserID(ctx, userID) {
|
||||
senderUserID = userID
|
||||
}
|
||||
if senderUserID == "" && senderChannelIdentityID != "" {
|
||||
@@ -936,14 +899,14 @@ func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelI
|
||||
|
||||
// resolveDisplayName returns the best available display name for the request identity:
|
||||
// req.DisplayName if set, else channel identity's display_name, else linked user's display_name, else "User".
|
||||
func (r *Resolver) resolveDisplayName(ctx context.Context, req ChatRequest) string {
|
||||
func (r *Resolver) resolveDisplayName(ctx context.Context, req conversation.ChatRequest) string {
|
||||
if name := strings.TrimSpace(req.DisplayName); name != "" {
|
||||
return name
|
||||
}
|
||||
if r.queries == nil {
|
||||
return "User"
|
||||
}
|
||||
channelIdentityID := firstNonEmpty(req.SourceChannelIdentityID, req.UserID)
|
||||
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
|
||||
if channelIdentityID == "" {
|
||||
return "User"
|
||||
}
|
||||
@@ -975,7 +938,7 @@ func (r *Resolver) resolveDisplayName(ctx context.Context, req ChatRequest) stri
|
||||
return "User"
|
||||
}
|
||||
|
||||
func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) {
|
||||
func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []conversation.ModelMessage) {
|
||||
if r.memoryService == nil {
|
||||
return
|
||||
}
|
||||
@@ -1004,7 +967,7 @@ func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Me
|
||||
filters := map[string]any{
|
||||
"namespace": namespace,
|
||||
"scopeId": scopeID,
|
||||
"botId": botID,
|
||||
"bot_id": botID,
|
||||
}
|
||||
if _, err := r.memoryService.Add(ctx, memory.AddRequest{
|
||||
Messages: msgs,
|
||||
@@ -1021,21 +984,19 @@ func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Me
|
||||
|
||||
// --- model selection ---
|
||||
|
||||
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||
func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.LlmProvider, error) {
|
||||
if r.modelsService == nil {
|
||||
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
|
||||
}
|
||||
modelID := strings.TrimSpace(req.Model)
|
||||
providerFilter := strings.TrimSpace(req.Provider)
|
||||
|
||||
// Priority: request model > chat settings > bot settings > user settings.
|
||||
// Priority: request model > chat settings > bot settings.
|
||||
if modelID == "" && providerFilter == "" {
|
||||
if value := strings.TrimSpace(cs.ModelID); value != "" {
|
||||
modelID = value
|
||||
} else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" {
|
||||
modelID = value
|
||||
} else if value := strings.TrimSpace(us.ChatModelID); value != "" {
|
||||
modelID = value
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1100,29 +1061,9 @@ func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([
|
||||
|
||||
// --- settings ---
|
||||
|
||||
type resolvedUserSettings struct {
|
||||
ChatModelID string
|
||||
}
|
||||
|
||||
func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (resolvedUserSettings, error) {
|
||||
if r.settingsService == nil || strings.TrimSpace(userID) == "" {
|
||||
return resolvedUserSettings{}, nil
|
||||
}
|
||||
s, err := r.settingsService.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return resolvedUserSettings{}, err
|
||||
}
|
||||
return resolvedUserSettings{
|
||||
ChatModelID: strings.TrimSpace(s.ChatModelID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) {
|
||||
if r.settingsService == nil {
|
||||
return settings.Settings{
|
||||
MaxContextLoadTime: settings.DefaultMaxContextLoadTime,
|
||||
Language: settings.DefaultLanguage,
|
||||
}, nil
|
||||
return settings.Settings{}, fmt.Errorf("settings service not configured")
|
||||
}
|
||||
return r.settingsService.GetBot(ctx, botID)
|
||||
}
|
||||
@@ -1140,8 +1081,8 @@ func normalizeClientType(clientType string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeMessages(messages []ModelMessage) []ModelMessage {
|
||||
cleaned := make([]ModelMessage, 0, len(messages))
|
||||
func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage {
|
||||
cleaned := make([]conversation.ModelMessage, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
if strings.TrimSpace(msg.Role) == "" {
|
||||
continue
|
||||
@@ -1196,9 +1137,9 @@ func nonNilStrings(s []string) []string {
|
||||
return s
|
||||
}
|
||||
|
||||
func nonNilModelMessages(m []ModelMessage) []ModelMessage {
|
||||
func nonNilModelMessages(m []conversation.ModelMessage) []conversation.ModelMessage {
|
||||
if m == nil {
|
||||
return []ModelMessage{}
|
||||
return []conversation.ModelMessage{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/memory"
|
||||
)
|
||||
|
||||
@@ -14,7 +15,7 @@ func TestLoadMemoryContextMessage_NoMemoryService(t *testing.T) {
|
||||
memoryService: nil,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{
|
||||
msg := resolver.loadMemoryContextMessage(context.Background(), conversation.ChatRequest{
|
||||
Query: "hello",
|
||||
BotID: "bot-1",
|
||||
ChatID: "chat-1",
|
||||
@@ -29,7 +30,7 @@ func TestLoadMemoryContextMessage_SearchFailureFallback(t *testing.T) {
|
||||
memoryService: &memory.Service{},
|
||||
logger: slog.Default(),
|
||||
}
|
||||
msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{
|
||||
msg := resolver.loadMemoryContextMessage(context.Background(), conversation.ChatRequest{
|
||||
Query: "hello",
|
||||
BotID: "bot-1",
|
||||
ChatID: "chat-1",
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
)
|
||||
|
||||
func TestPostTriggerSchedule_Endpoint(t *testing.T) {
|
||||
@@ -21,7 +23,7 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
capturedBody, _ = io.ReadAll(r.Body)
|
||||
resp := gatewayResponse{
|
||||
Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}},
|
||||
Messages: []conversation.ModelMessage{{Role: "assistant", Content: conversation.NewTextContent("ok")}},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
@@ -36,23 +38,25 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) {
|
||||
|
||||
maxCalls := 5
|
||||
req := triggerScheduleRequest{
|
||||
Model: gatewayModelConfig{
|
||||
ModelID: "gpt-4",
|
||||
ClientType: "openai",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com",
|
||||
gatewayRequest: gatewayRequest{
|
||||
Model: gatewayModelConfig{
|
||||
ModelID: "gpt-4",
|
||||
ClientType: "openai",
|
||||
APIKey: "sk-test",
|
||||
BaseURL: "https://api.openai.com",
|
||||
},
|
||||
ActiveContextTime: 1440,
|
||||
Channels: []string{},
|
||||
Messages: []conversation.ModelMessage{},
|
||||
Skills: []string{},
|
||||
Identity: gatewayIdentity{
|
||||
BotID: "bot-123",
|
||||
ContainerID: "mcp-bot-123",
|
||||
ChannelIdentityID: "owner-user-1",
|
||||
DisplayName: "Scheduler",
|
||||
},
|
||||
Attachments: []any{},
|
||||
},
|
||||
ActiveContextTime: 1440,
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Identity: gatewayIdentity{
|
||||
BotID: "bot-123",
|
||||
ContainerID: "mcp-bot-123",
|
||||
ChannelIdentityID: "owner-user-1",
|
||||
DisplayName: "Scheduler",
|
||||
},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{
|
||||
ID: "sched-1",
|
||||
Name: "daily report",
|
||||
@@ -102,7 +106,7 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) {
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
resp := gatewayResponse{Messages: []ModelMessage{}}
|
||||
resp := gatewayResponse{Messages: []conversation.ModelMessage{}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
@@ -114,11 +118,13 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) {
|
||||
}
|
||||
|
||||
req := triggerScheduleRequest{
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
gatewayRequest: gatewayRequest{
|
||||
Channels: []string{},
|
||||
Messages: []conversation.ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
}
|
||||
|
||||
_, err := resolver.postTriggerSchedule(context.Background(), req, "")
|
||||
@@ -144,11 +150,13 @@ func TestPostTriggerSchedule_GatewayError(t *testing.T) {
|
||||
}
|
||||
|
||||
req := triggerScheduleRequest{
|
||||
Channels: []string{},
|
||||
Messages: []ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
gatewayRequest: gatewayRequest{
|
||||
Channels: []string{},
|
||||
Messages: []conversation.ModelMessage{},
|
||||
Skills: []string{},
|
||||
Attachments: []any{},
|
||||
},
|
||||
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
|
||||
}
|
||||
|
||||
_, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok")
|
||||
|
||||
@@ -3,12 +3,13 @@ package flow
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
)
|
||||
|
||||
// Runner defines conversation execution behavior for sync, stream, and scheduled flows.
|
||||
type Runner interface {
|
||||
Chat(ctx context.Context, req ChatRequest) (ChatResponse, error)
|
||||
StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error)
|
||||
Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error)
|
||||
StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error)
|
||||
TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error
|
||||
}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package flow
|
||||
|
||||
import "github.com/memohai/memoh/internal/conversation"
|
||||
|
||||
type ModelMessage = conversation.ModelMessage
|
||||
type ContentPart = conversation.ContentPart
|
||||
type ToolCall = conversation.ToolCall
|
||||
type ToolCallFunction = conversation.ToolCallFunction
|
||||
type AssistantOutput = conversation.AssistantOutput
|
||||
type ChatRequest = conversation.ChatRequest
|
||||
type ChatResponse = conversation.ChatResponse
|
||||
type StreamChunk = conversation.StreamChunk
|
||||
type Settings = conversation.Settings
|
||||
|
||||
var NewTextContent = conversation.NewTextContent
|
||||
@@ -4,7 +4,7 @@ import "context"
|
||||
|
||||
// Reader defines conversation lookup behavior.
|
||||
type Reader interface {
|
||||
Get(ctx context.Context, conversationID string) (Chat, error)
|
||||
Get(ctx context.Context, conversationID string) (Conversation, error)
|
||||
}
|
||||
|
||||
// ParticipantChecker defines participant membership checks.
|
||||
@@ -16,5 +16,5 @@ type ParticipantChecker interface {
|
||||
type Accessor interface {
|
||||
Reader
|
||||
ParticipantChecker
|
||||
GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error)
|
||||
GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ConversationReadAccess, error)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -40,24 +40,24 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||
}
|
||||
|
||||
// Create creates a new conversation and adds the creator as owner.
|
||||
func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Chat, error) {
|
||||
func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Conversation, error) {
|
||||
kind := strings.TrimSpace(req.Kind)
|
||||
if kind == "" {
|
||||
kind = KindDirect
|
||||
}
|
||||
if kind != KindDirect && kind != KindGroup && kind != KindThread {
|
||||
return Chat{}, fmt.Errorf("invalid conversation kind: %s", kind)
|
||||
return Conversation{}, fmt.Errorf("invalid conversation kind: %s", kind)
|
||||
}
|
||||
|
||||
pgBotID, err := parseUUID(botID)
|
||||
if err != nil {
|
||||
return Chat{}, fmt.Errorf("invalid bot id: %w", err)
|
||||
return Conversation{}, fmt.Errorf("invalid bot id: %w", err)
|
||||
}
|
||||
pgChannelIdentityID := pgtype.UUID{}
|
||||
if strings.TrimSpace(channelIdentityID) != "" {
|
||||
pgChannelIdentityID, err = parseUUID(channelIdentityID)
|
||||
if err != nil {
|
||||
return Chat{}, fmt.Errorf("invalid channel identity id: %w", err)
|
||||
return Conversation{}, fmt.Errorf("invalid channel identity id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,13 +65,13 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
|
||||
if kind == KindThread && strings.TrimSpace(req.ParentChatID) != "" {
|
||||
pgParent, err = parseUUID(req.ParentChatID)
|
||||
if err != nil {
|
||||
return Chat{}, fmt.Errorf("invalid parent conversation id: %w", err)
|
||||
return Conversation{}, fmt.Errorf("invalid parent conversation id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
metadata, err := json.Marshal(nonNilMap(req.Metadata))
|
||||
if err != nil {
|
||||
return Chat{}, fmt.Errorf("marshal conversation metadata: %w", err)
|
||||
return Conversation{}, fmt.Errorf("marshal conversation metadata: %w", err)
|
||||
}
|
||||
|
||||
row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{
|
||||
@@ -83,7 +83,7 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
|
||||
Metadata: metadata,
|
||||
})
|
||||
if err != nil {
|
||||
return Chat{}, fmt.Errorf("create conversation: %w", err)
|
||||
return Conversation{}, fmt.Errorf("create conversation: %w", err)
|
||||
}
|
||||
|
||||
// Add creator as owner when the channel identity is available.
|
||||
@@ -93,7 +93,7 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
|
||||
UserID: pgChannelIdentityID,
|
||||
Role: RoleOwner,
|
||||
}); err != nil {
|
||||
return Chat{}, fmt.Errorf("add owner participant: %w", err)
|
||||
return Conversation{}, fmt.Errorf("add owner participant: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
|
||||
if err := s.queries.CopyParticipantsToChat(ctx, sqlc.CopyParticipantsToChatParams{
|
||||
ChatID: pgParent,
|
||||
ChatID2: row.ID,
|
||||
}); err != nil && s.logger != nil {
|
||||
}); err != nil {
|
||||
s.logger.Warn("copy parent participants failed", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
@@ -111,30 +111,30 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
|
||||
}
|
||||
|
||||
// Get returns a conversation by ID.
|
||||
func (s *Service) Get(ctx context.Context, conversationID string) (Chat, error) {
|
||||
func (s *Service) Get(ctx context.Context, conversationID string) (Conversation, error) {
|
||||
pgID, err := parseUUID(conversationID)
|
||||
if err != nil {
|
||||
return Chat{}, ErrChatNotFound
|
||||
return Conversation{}, ErrChatNotFound
|
||||
}
|
||||
row, err := s.queries.GetChatByID(ctx, pgID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return Chat{}, ErrChatNotFound
|
||||
return Conversation{}, ErrChatNotFound
|
||||
}
|
||||
return Chat{}, err
|
||||
return Conversation{}, err
|
||||
}
|
||||
return toChatFromGet(row), nil
|
||||
}
|
||||
|
||||
// GetReadAccess resolves whether a user can read a conversation.
|
||||
func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error) {
|
||||
func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ConversationReadAccess, error) {
|
||||
pgConversationID, err := parseUUID(conversationID)
|
||||
if err != nil {
|
||||
return ChatReadAccess{}, ErrPermissionDenied
|
||||
return ConversationReadAccess{}, ErrPermissionDenied
|
||||
}
|
||||
pgChannelIdentityID, err := parseUUID(channelIdentityID)
|
||||
if err != nil {
|
||||
return ChatReadAccess{}, ErrPermissionDenied
|
||||
return ConversationReadAccess{}, ErrPermissionDenied
|
||||
}
|
||||
row, err := s.queries.GetChatReadAccessByUser(ctx, sqlc.GetChatReadAccessByUserParams{
|
||||
ChatID: pgConversationID,
|
||||
@@ -142,11 +142,11 @@ func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIden
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ChatReadAccess{}, ErrPermissionDenied
|
||||
return ConversationReadAccess{}, ErrPermissionDenied
|
||||
}
|
||||
return ChatReadAccess{}, err
|
||||
return ConversationReadAccess{}, err
|
||||
}
|
||||
return ChatReadAccess{
|
||||
return ConversationReadAccess{
|
||||
AccessMode: row.AccessMode,
|
||||
ParticipantRole: strings.TrimSpace(row.ParticipantRole),
|
||||
LastObservedAt: pgTimePtr(row.LastObservedAt),
|
||||
@@ -154,7 +154,7 @@ func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIden
|
||||
}
|
||||
|
||||
// ListByBotAndChannelIdentity returns all visible conversations for a bot and channel identity.
|
||||
func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ChatListItem, error) {
|
||||
func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ConversationListItem, error) {
|
||||
pgBotID, err := parseUUID(botID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -170,7 +170,7 @@ func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channe
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conversations := make([]ChatListItem, 0, len(rows))
|
||||
conversations := make([]ConversationListItem, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
conversations = append(conversations, toChatListItem(row))
|
||||
}
|
||||
@@ -178,7 +178,7 @@ func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channe
|
||||
}
|
||||
|
||||
// ListThreads returns threads for a parent conversation.
|
||||
func (s *Service) ListThreads(ctx context.Context, parentConversationID string) ([]Chat, error) {
|
||||
func (s *Service) ListThreads(ctx context.Context, parentConversationID string) ([]Conversation, error) {
|
||||
pgID, err := parseUUID(parentConversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -187,7 +187,7 @@ func (s *Service) ListThreads(ctx context.Context, parentConversationID string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conversations := make([]Chat, 0, len(rows))
|
||||
conversations := make([]Conversation, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
conversations = append(conversations, toChatFromThread(row))
|
||||
}
|
||||
@@ -296,7 +296,7 @@ func (s *Service) RemoveParticipant(ctx context.Context, conversationID, channel
|
||||
func (s *Service) GetSettings(ctx context.Context, conversationID string) (Settings, error) {
|
||||
pgID, err := parseUUID(conversationID)
|
||||
if err != nil {
|
||||
return defaultSettings(conversationID), nil
|
||||
return Settings{}, fmt.Errorf("invalid conversation id: %w", err)
|
||||
}
|
||||
row, err := s.queries.GetChatSettings(ctx, pgID)
|
||||
if err != nil {
|
||||
@@ -332,7 +332,7 @@ func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req
|
||||
return toSettingsFromUpsert(row), nil
|
||||
}
|
||||
|
||||
func toChatFromCreate(row sqlc.CreateChatRow) Chat {
|
||||
func toChatFromCreate(row sqlc.CreateChatRow) Conversation {
|
||||
return toChatFields(
|
||||
row.ID,
|
||||
row.BotID,
|
||||
@@ -346,7 +346,7 @@ func toChatFromCreate(row sqlc.CreateChatRow) Chat {
|
||||
)
|
||||
}
|
||||
|
||||
func toChatFromGet(row sqlc.GetChatByIDRow) Chat {
|
||||
func toChatFromGet(row sqlc.GetChatByIDRow) Conversation {
|
||||
return toChatFields(
|
||||
row.ID,
|
||||
row.BotID,
|
||||
@@ -360,7 +360,7 @@ func toChatFromGet(row sqlc.GetChatByIDRow) Chat {
|
||||
)
|
||||
}
|
||||
|
||||
func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat {
|
||||
func toChatFromThread(row sqlc.ListThreadsByParentRow) Conversation {
|
||||
return toChatFields(
|
||||
row.ID,
|
||||
row.BotID,
|
||||
@@ -374,8 +374,8 @@ func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat {
|
||||
)
|
||||
}
|
||||
|
||||
func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Chat {
|
||||
return Chat{
|
||||
func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Conversation {
|
||||
return Conversation{
|
||||
ID: id.String(),
|
||||
BotID: botID.String(),
|
||||
Kind: kind,
|
||||
@@ -388,8 +388,8 @@ func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID,
|
||||
}
|
||||
}
|
||||
|
||||
func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem {
|
||||
return ChatListItem{
|
||||
func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ConversationListItem {
|
||||
return ConversationListItem{
|
||||
ID: row.ID.String(),
|
||||
BotID: row.BotID.String(),
|
||||
Kind: row.Kind,
|
||||
@@ -478,6 +478,8 @@ func parseJSONMap(data []byte) map[string]any {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(data, &m)
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err))
|
||||
}
|
||||
return m
|
||||
}
|
||||
+1
-1
@@ -175,7 +175,7 @@ func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) {
|
||||
t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind))
|
||||
}
|
||||
|
||||
var target *conversation.ChatListItem
|
||||
var target *conversation.ConversationListItem
|
||||
for i := range afterBind {
|
||||
if afterBind[i].ID == chatID {
|
||||
target = &afterBind[i]
|
||||
@@ -3,11 +3,12 @@ package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Chat kind constants.
|
||||
// Conversation kind constants.
|
||||
const (
|
||||
KindDirect = "direct"
|
||||
KindGroup = "group"
|
||||
@@ -21,7 +22,7 @@ const (
|
||||
RoleMember = "member"
|
||||
)
|
||||
|
||||
// Chat list access mode constants.
|
||||
// Conversation list access mode constants.
|
||||
const (
|
||||
AccessModeParticipant = "participant"
|
||||
AccessModeChannelIdentityObserved = "channel_identity_observed"
|
||||
@@ -63,11 +64,6 @@ type ConversationReadAccess struct {
|
||||
LastObservedAt *time.Time
|
||||
}
|
||||
|
||||
// Backward-compatible aliases while call sites migrate.
|
||||
type Chat = Conversation
|
||||
type ChatListItem = ConversationListItem
|
||||
type ChatReadAccess = ConversationReadAccess
|
||||
|
||||
// Participant represents a chat member.
|
||||
type Participant struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
@@ -155,7 +151,11 @@ func (m ModelMessage) HasContent() bool {
|
||||
|
||||
// NewTextContent creates a json.RawMessage from a plain string.
|
||||
func NewTextContent(text string) json.RawMessage {
|
||||
data, _ := json.Marshal(text)
|
||||
data, err := json.Marshal(text)
|
||||
if err != nil {
|
||||
slog.Warn("NewTextContent: marshal failed", slog.Any("error", err))
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -209,7 +209,6 @@ type ChatRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
MaxContextLoadTime int `json:"max_context_load_time,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Channels []string `json:"channels,omitempty"`
|
||||
CurrentChannel string `json:"current_channel,omitempty"`
|
||||
Messages []ModelMessage `json:"messages,omitempty"`
|
||||
|
||||
@@ -11,30 +11,6 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const clearChannelIdentityLinkedUser = `-- name: ClearChannelIdentityLinkedUser :one
|
||||
UPDATE channel_identities
|
||||
SET user_id = NULL, updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id, user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
func (q *Queries) ClearChannelIdentityLinkedUser(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) {
|
||||
row := q.db.QueryRow(ctx, clearChannelIdentityLinkedUser, id)
|
||||
var i ChannelIdentity
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.ChannelType,
|
||||
&i.ChannelSubjectID,
|
||||
&i.DisplayName,
|
||||
&i.AvatarUrl,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const createChannelIdentity = `-- name: CreateChannelIdentity :one
|
||||
INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
|
||||
@@ -46,32 +46,6 @@ func (q *Queries) GetContainerByBotID(ctx context.Context, botID pgtype.UUID) (C
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getContainerByContainerID = `-- name: GetContainerByContainerID :one
|
||||
SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE container_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetContainerByContainerID(ctx context.Context, containerID string) (Container, error) {
|
||||
row := q.db.QueryRow(ctx, getContainerByContainerID, containerID)
|
||||
var i Container
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.BotID,
|
||||
&i.ContainerID,
|
||||
&i.ContainerName,
|
||||
&i.Image,
|
||||
&i.Status,
|
||||
&i.Namespace,
|
||||
&i.AutoStart,
|
||||
&i.HostPath,
|
||||
&i.ContainerPath,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LastStartedAt,
|
||||
&i.LastStoppedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listAutoStartContainers = `-- name: ListAutoStartContainers :many
|
||||
SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE auto_start = true ORDER BY updated_at DESC
|
||||
`
|
||||
|
||||
@@ -428,7 +428,6 @@ SELECT
|
||||
FROM bots b
|
||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||
WHERE b.id = $1
|
||||
AND false
|
||||
ORDER BY b.created_at DESC
|
||||
`
|
||||
|
||||
@@ -485,6 +484,7 @@ SELECT
|
||||
b.display_name AS title,
|
||||
b.owner_user_id AS created_by_user_id,
|
||||
b.metadata AS metadata,
|
||||
chat_models.model_id AS model_id,
|
||||
b.created_at,
|
||||
b.updated_at,
|
||||
'participant'::text AS access_mode,
|
||||
@@ -495,6 +495,7 @@ SELECT
|
||||
NULL::timestamptz AS last_observed_at
|
||||
FROM bots b
|
||||
LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1
|
||||
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
|
||||
WHERE b.id = $2
|
||||
AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL)
|
||||
ORDER BY b.updated_at DESC
|
||||
@@ -513,6 +514,7 @@ type ListVisibleChatsByBotAndUserRow struct {
|
||||
Title pgtype.Text `json:"title"`
|
||||
CreatedByUserID pgtype.UUID `json:"created_by_user_id"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
ModelID pgtype.Text `json:"model_id"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
AccessMode string `json:"access_mode"`
|
||||
@@ -537,6 +539,7 @@ func (q *Queries) ListVisibleChatsByBotAndUser(ctx context.Context, arg ListVisi
|
||||
&i.Title,
|
||||
&i.CreatedByUserID,
|
||||
&i.Metadata,
|
||||
&i.ModelID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.AccessMode,
|
||||
@@ -582,26 +585,31 @@ func (q *Queries) TouchChat(ctx context.Context, chatID pgtype.UUID) error {
|
||||
}
|
||||
|
||||
const updateChatTitle = `-- name: UpdateChatTitle :one
|
||||
UPDATE bots
|
||||
SET display_name = $1,
|
||||
updated_at = now()
|
||||
WHERE id = $2
|
||||
RETURNING
|
||||
id,
|
||||
id AS bot_id,
|
||||
CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind,
|
||||
WITH updated AS (
|
||||
UPDATE bots
|
||||
SET display_name = $1,
|
||||
updated_at = now()
|
||||
WHERE bots.id = $2
|
||||
RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at
|
||||
)
|
||||
SELECT
|
||||
updated.id AS id,
|
||||
updated.id AS bot_id,
|
||||
CASE WHEN updated.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
|
||||
NULL::uuid AS parent_chat_id,
|
||||
display_name AS title,
|
||||
owner_user_id AS created_by_user_id,
|
||||
metadata,
|
||||
NULL::text AS model_id,
|
||||
created_at,
|
||||
updated_at
|
||||
updated.display_name AS title,
|
||||
updated.owner_user_id AS created_by_user_id,
|
||||
updated.metadata,
|
||||
chat_models.model_id AS model_id,
|
||||
updated.created_at,
|
||||
updated.updated_at
|
||||
FROM updated
|
||||
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id
|
||||
`
|
||||
|
||||
type UpdateChatTitleParams struct {
|
||||
Title pgtype.Text `json:"title"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
BotID pgtype.UUID `json:"bot_id"`
|
||||
}
|
||||
|
||||
type UpdateChatTitleRow struct {
|
||||
@@ -618,7 +626,7 @@ type UpdateChatTitleRow struct {
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams) (UpdateChatTitleRow, error) {
|
||||
row := q.db.QueryRow(ctx, updateChatTitle, arg.Title, arg.ID)
|
||||
row := q.db.QueryRow(ctx, updateChatTitle, arg.Title, arg.BotID)
|
||||
var i UpdateChatTitleRow
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
|
||||
@@ -35,33 +35,3 @@ func (q *Queries) InsertLifecycleEvent(ctx context.Context, arg InsertLifecycleE
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const listLifecycleEventsByContainerID = `-- name: ListLifecycleEventsByContainerID :many
|
||||
SELECT id, container_id, event_type, payload, created_at FROM lifecycle_events WHERE container_id = $1 ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
func (q *Queries) ListLifecycleEventsByContainerID(ctx context.Context, containerID string) ([]LifecycleEvent, error) {
|
||||
rows, err := q.db.Query(ctx, listLifecycleEventsByContainerID, containerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []LifecycleEvent
|
||||
for rows.Next() {
|
||||
var i LifecycleEvent
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ContainerID,
|
||||
&i.EventType,
|
||||
&i.Payload,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
@@ -140,6 +140,7 @@ FROM bot_history_messages m
|
||||
LEFT JOIN channel_identities ci ON ci.id = m.sender_channel_identity_id
|
||||
WHERE m.bot_id = $1
|
||||
ORDER BY m.created_at ASC
|
||||
LIMIT 10000
|
||||
`
|
||||
|
||||
type ListMessagesRow struct {
|
||||
|
||||
+13
-18
@@ -225,24 +225,19 @@ type Subagent struct {
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Username pgtype.Text `json:"username"`
|
||||
Email pgtype.Text `json:"email"`
|
||||
PasswordHash pgtype.Text `json:"password_hash"`
|
||||
Role string `json:"role"`
|
||||
DisplayName pgtype.Text `json:"display_name"`
|
||||
AvatarUrl pgtype.Text `json:"avatar_url"`
|
||||
DataRoot pgtype.Text `json:"data_root"`
|
||||
LastLoginAt pgtype.Timestamptz `json:"last_login_at"`
|
||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
||||
MaxContextLoadTime int32 `json:"max_context_load_time"`
|
||||
Language string `json:"language"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
Username pgtype.Text `json:"username"`
|
||||
Email pgtype.Text `json:"email"`
|
||||
PasswordHash pgtype.Text `json:"password_hash"`
|
||||
Role string `json:"role"`
|
||||
DisplayName pgtype.Text `json:"display_name"`
|
||||
AvatarUrl pgtype.Text `json:"avatar_url"`
|
||||
DataRoot pgtype.Text `json:"data_root"`
|
||||
LastLoginAt pgtype.Timestamptz `json:"last_login_at"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserChannelBinding struct {
|
||||
|
||||
@@ -208,15 +208,6 @@ func (q *Queries) DeleteModelByModelID(ctx context.Context, modelID string) erro
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteModelVariant = `-- name: DeleteModelVariant :exec
|
||||
DELETE FROM model_variants WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteModelVariant(ctx context.Context, id pgtype.UUID) error {
|
||||
_, err := q.db.Exec(ctx, deleteModelVariant, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getLlmProviderByID = `-- name: GetLlmProviderByID :one
|
||||
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE id = $1
|
||||
`
|
||||
@@ -299,25 +290,6 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getModelVariantByID = `-- name: GetModelVariantByID :one
|
||||
SELECT id, model_uuid, variant_id, weight, metadata, created_at, updated_at FROM model_variants WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetModelVariantByID(ctx context.Context, id pgtype.UUID) (ModelVariant, error) {
|
||||
row := q.db.QueryRow(ctx, getModelVariantByID, id)
|
||||
var i ModelVariant
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ModelUuid,
|
||||
&i.VariantID,
|
||||
&i.Weight,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listLlmProviders = `-- name: ListLlmProviders :many
|
||||
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
|
||||
ORDER BY created_at DESC
|
||||
@@ -421,40 +393,6 @@ func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pg
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listModelVariantsByVariantID = `-- name: ListModelVariantsByVariantID :many
|
||||
SELECT id, model_uuid, variant_id, weight, metadata, created_at, updated_at FROM model_variants
|
||||
WHERE variant_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID string) ([]ModelVariant, error) {
|
||||
rows, err := q.db.Query(ctx, listModelVariantsByVariantID, variantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ModelVariant
|
||||
for rows.Next() {
|
||||
var i ModelVariant
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ModelUuid,
|
||||
&i.VariantID,
|
||||
&i.Weight,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listModels = `-- name: ListModels :many
|
||||
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
|
||||
ORDER BY created_at DESC
|
||||
@@ -777,41 +715,3 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateModelVariant = `-- name: UpdateModelVariant :one
|
||||
UPDATE model_variants
|
||||
SET
|
||||
variant_id = $1,
|
||||
weight = $2,
|
||||
metadata = $3,
|
||||
updated_at = now()
|
||||
WHERE id = $4
|
||||
RETURNING id, model_uuid, variant_id, weight, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateModelVariantParams struct {
|
||||
VariantID string `json:"variant_id"`
|
||||
Weight int32 `json:"weight"`
|
||||
Metadata []byte `json:"metadata"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateModelVariant(ctx context.Context, arg UpdateModelVariantParams) (ModelVariant, error) {
|
||||
row := q.db.QueryRow(ctx, updateModelVariant,
|
||||
arg.VariantID,
|
||||
arg.Weight,
|
||||
arg.Metadata,
|
||||
arg.ID,
|
||||
)
|
||||
var i ModelVariant
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ModelUuid,
|
||||
&i.VariantID,
|
||||
&i.Weight,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -48,6 +48,8 @@ const getBotPreauthKey = `-- name: GetBotPreauthKey :one
|
||||
SELECT id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at
|
||||
FROM bot_preauth_keys
|
||||
WHERE token = $1
|
||||
AND used_at IS NULL
|
||||
AND (expires_at IS NULL OR expires_at > now())
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ UPDATE bots
|
||||
SET max_context_load_time = 1440,
|
||||
language = 'auto',
|
||||
allow_guest = false,
|
||||
chat_model_id = NULL,
|
||||
memory_model_id = NULL,
|
||||
embedding_model_id = NULL,
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
`
|
||||
@@ -66,35 +69,6 @@ func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSe
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSettingsByUserID = `-- name: GetSettingsByUserID :one
|
||||
SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type GetSettingsByUserIDRow struct {
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
||||
MaxContextLoadTime int32 `json:"max_context_load_time"`
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSettingsByUserID(ctx context.Context, id pgtype.UUID) (GetSettingsByUserIDRow, error) {
|
||||
row := q.db.QueryRow(ctx, getSettingsByUserID, id)
|
||||
var i GetSettingsByUserIDRow
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertBotSettings = `-- name: UpsertBotSettings :one
|
||||
WITH updated AS (
|
||||
UPDATE bots
|
||||
@@ -164,54 +138,3 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertUserSettings = `-- name: UpsertUserSettings :one
|
||||
UPDATE users
|
||||
SET chat_model_id = $2,
|
||||
memory_model_id = $3,
|
||||
embedding_model_id = $4,
|
||||
max_context_load_time = $5,
|
||||
language = $6,
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
|
||||
`
|
||||
|
||||
type UpsertUserSettingsParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
||||
MaxContextLoadTime int32 `json:"max_context_load_time"`
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
type UpsertUserSettingsRow struct {
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
ChatModelID pgtype.Text `json:"chat_model_id"`
|
||||
MemoryModelID pgtype.Text `json:"memory_model_id"`
|
||||
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
|
||||
MaxContextLoadTime int32 `json:"max_context_load_time"`
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpsertUserSettings(ctx context.Context, arg UpsertUserSettingsParams) (UpsertUserSettingsRow, error) {
|
||||
row := q.db.QueryRow(ctx, upsertUserSettings,
|
||||
arg.ID,
|
||||
arg.ChatModelID,
|
||||
arg.MemoryModelID,
|
||||
arg.EmbeddingModelID,
|
||||
arg.MaxContextLoadTime,
|
||||
arg.Language,
|
||||
)
|
||||
var i UpsertUserSettingsRow
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -41,34 +41,3 @@ func (q *Queries) InsertSnapshot(ctx context.Context, arg InsertSnapshotParams)
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const listSnapshotsByContainerID = `-- name: ListSnapshotsByContainerID :many
|
||||
SELECT id, container_id, parent_snapshot_id, snapshotter, digest, created_at FROM snapshots WHERE container_id = $1 ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
func (q *Queries) ListSnapshotsByContainerID(ctx context.Context, containerID string) ([]Snapshot, error) {
|
||||
rows, err := q.db.Query(ctx, listSnapshotsByContainerID, containerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Snapshot
|
||||
for rows.Next() {
|
||||
var i Snapshot
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ContainerID,
|
||||
&i.ParentSnapshotID,
|
||||
&i.Snapshotter,
|
||||
&i.Digest,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
+11
-135
@@ -37,7 +37,7 @@ SET username = $1,
|
||||
data_root = $8,
|
||||
updated_at = now()
|
||||
WHERE id = $9
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateAccountParams struct {
|
||||
@@ -75,11 +75,6 @@ func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (U
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -91,7 +86,7 @@ func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (U
|
||||
const createUser = `-- name: CreateUser :one
|
||||
INSERT INTO users (is_active, metadata)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
@@ -112,11 +107,6 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -126,7 +116,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e
|
||||
}
|
||||
|
||||
const getAccountByIdentity = `-- name: GetAccountByIdentity :one
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 OR email = $1
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 OR email = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text) (User, error) {
|
||||
@@ -142,11 +132,6 @@ func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -156,7 +141,7 @@ func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text
|
||||
}
|
||||
|
||||
const getAccountByUserID = `-- name: GetAccountByUserID :one
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE id = $1
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at FROM users WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetAccountByUserID(ctx context.Context, userID pgtype.UUID) (User, error) {
|
||||
@@ -172,41 +157,6 @@ func (q *Queries) GetAccountByUserID(ctx context.Context, userID pgtype.UUID) (U
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getAccountByUsername = `-- name: GetAccountByUsername :one
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetAccountByUsername(ctx context.Context, username pgtype.Text) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getAccountByUsername, username)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Role,
|
||||
&i.DisplayName,
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -216,7 +166,7 @@ func (q *Queries) GetAccountByUsername(ctx context.Context, username pgtype.Text
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
`
|
||||
@@ -234,11 +184,6 @@ func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error)
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -248,7 +193,7 @@ func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error)
|
||||
}
|
||||
|
||||
const listAccounts = `-- name: ListAccounts :many
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users
|
||||
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at FROM users
|
||||
WHERE username IS NOT NULL
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
@@ -272,11 +217,6 @@ func (q *Queries) ListAccounts(ctx context.Context) ([]User, error) {
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -300,7 +240,7 @@ SET role = $1::user_role,
|
||||
is_active = $4,
|
||||
updated_at = now()
|
||||
WHERE id = $5
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateAccountAdminParams struct {
|
||||
@@ -330,11 +270,6 @@ func (q *Queries) UpdateAccountAdmin(ctx context.Context, arg UpdateAccountAdmin
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -348,7 +283,7 @@ UPDATE users
|
||||
SET last_login_at = now(),
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
func (q *Queries) UpdateAccountLastLogin(ctx context.Context, id pgtype.UUID) (User, error) {
|
||||
@@ -364,11 +299,6 @@ func (q *Queries) UpdateAccountLastLogin(ctx context.Context, id pgtype.UUID) (U
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -382,7 +312,7 @@ UPDATE users
|
||||
SET password_hash = $2,
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateAccountPasswordParams struct {
|
||||
@@ -403,11 +333,6 @@ func (q *Queries) UpdateAccountPassword(ctx context.Context, arg UpdateAccountPa
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -423,7 +348,7 @@ SET display_name = $2,
|
||||
is_active = $4,
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateAccountProfileParams struct {
|
||||
@@ -451,50 +376,6 @@ func (q *Queries) UpdateAccountProfile(ctx context.Context, arg UpdateAccountPro
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateUserStatus = `-- name: UpdateUserStatus :one
|
||||
UPDATE users
|
||||
SET is_active = $2,
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateUserStatusParams struct {
|
||||
ID pgtype.UUID `json:"id"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, updateUserStatus, arg.ID, arg.IsActive)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.Role,
|
||||
&i.DisplayName,
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
@@ -526,7 +407,7 @@ ON CONFLICT (username) DO UPDATE SET
|
||||
is_active = EXCLUDED.is_active,
|
||||
data_root = EXCLUDED.data_root,
|
||||
updated_at = now()
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
|
||||
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpsertAccountByUsernameParams struct {
|
||||
@@ -564,11 +445,6 @@ func (q *Queries) UpsertAccountByUsername(ctx context.Context, arg UpsertAccount
|
||||
&i.AvatarUrl,
|
||||
&i.DataRoot,
|
||||
&i.LastLoginAt,
|
||||
&i.ChatModelID,
|
||||
&i.MemoryModelID,
|
||||
&i.EmbeddingModelID,
|
||||
&i.MaxContextLoadTime,
|
||||
&i.Language,
|
||||
&i.IsActive,
|
||||
&i.Metadata,
|
||||
&i.CreatedAt,
|
||||
|
||||
@@ -3,16 +3,13 @@ package embeddings
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
)
|
||||
@@ -27,12 +24,11 @@ const (
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Type string
|
||||
Provider string
|
||||
Model string
|
||||
Dimensions int
|
||||
Input Input
|
||||
ChannelIdentityID string
|
||||
Type string
|
||||
Provider string
|
||||
Model string
|
||||
Dimensions int
|
||||
Input Input
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
@@ -44,7 +40,7 @@ type Input struct {
|
||||
type Usage struct {
|
||||
InputTokens int
|
||||
ImageTokens int
|
||||
VideoTokens int
|
||||
Duration int
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
@@ -165,7 +161,7 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
|
||||
Usage: Usage{
|
||||
InputTokens: usage.InputTokens,
|
||||
ImageTokens: usage.ImageTokens,
|
||||
VideoTokens: usage.Duration,
|
||||
Duration: usage.Duration,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -180,30 +176,6 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model
|
||||
return models.GetResponse{}, errors.New("models service not configured")
|
||||
}
|
||||
|
||||
// If no model specified and no provider specified, try to get per-user embedding model.
|
||||
if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.ChannelIdentityID) != "" {
|
||||
modelID, err := r.loadChannelIdentityEmbeddingModelID(ctx, req.ChannelIdentityID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, err
|
||||
}
|
||||
if modelID != "" {
|
||||
selected, err := r.modelsService.GetByModelID(ctx, modelID)
|
||||
if err != nil {
|
||||
return models.GetResponse{}, fmt.Errorf("settings embedding model not found: %w", err)
|
||||
}
|
||||
if selected.Type != models.ModelTypeEmbedding {
|
||||
return models.GetResponse{}, errors.New("settings embedding model is not an embedding model")
|
||||
}
|
||||
if req.Type == TypeMultimodal && !selected.IsMultimodal {
|
||||
return models.GetResponse{}, errors.New("settings embedding model does not support multimodal")
|
||||
}
|
||||
if req.Type == TypeText && selected.IsMultimodal {
|
||||
return models.GetResponse{}, errors.New("settings embedding model does not support text embeddings")
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
}
|
||||
|
||||
var candidates []models.GetResponse
|
||||
var err error
|
||||
if req.Provider != "" {
|
||||
@@ -257,22 +229,3 @@ func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.L
|
||||
copy(pgID.Bytes[:], parsed[:])
|
||||
return r.queries.GetLlmProviderByID(ctx, pgID)
|
||||
}
|
||||
|
||||
func (r *Resolver) loadChannelIdentityEmbeddingModelID(ctx context.Context, channelIdentityID string) (string, error) {
|
||||
if r.queries == nil {
|
||||
return "", nil
|
||||
}
|
||||
pgChannelIdentityID, err := db.ParseUUID(channelIdentityID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
row, err := r.queries.GetSettingsByUserID(ctx, pgChannelIdentityID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(row.EmbeddingModelID.String), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bind"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
)
|
||||
|
||||
// BindHandler manages channel identity bind code issuance via REST API.
|
||||
@@ -80,12 +78,5 @@ func (h *BindHandler) Issue(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *BindHandler) requireUserID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return userID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
@@ -7,9 +7,7 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
)
|
||||
|
||||
type ChannelHandler struct {
|
||||
@@ -161,12 +159,5 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
@@ -18,18 +18,17 @@ import (
|
||||
"github.com/containerd/containerd/v2/pkg/oci"
|
||||
"github.com/containerd/errdefs"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
ctr "github.com/memohai/memoh/internal/containerd"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
"github.com/memohai/memoh/internal/policy"
|
||||
)
|
||||
@@ -166,7 +165,10 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
if dataRoot == "" {
|
||||
dataRoot = config.DefaultDataRoot
|
||||
}
|
||||
dataRoot, _ = filepath.Abs(dataRoot)
|
||||
dataRoot, err = filepath.Abs(dataRoot)
|
||||
if err != nil {
|
||||
h.logger.Warn("filepath.Abs failed", slog.Any("error", err))
|
||||
}
|
||||
dataMount := strings.TrimSpace(h.cfg.DataMount)
|
||||
if dataMount == "" {
|
||||
dataMount = config.DefaultDataMount
|
||||
@@ -253,7 +255,9 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true})
|
||||
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}); err != nil {
|
||||
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
h.logger.Error("mcp container network setup failed",
|
||||
slog.String("container_id", containerID),
|
||||
slog.Any("error", netErr),
|
||||
@@ -303,7 +307,9 @@ func (h *ContainerdHandler) ensureContainerAndTask(ctx context.Context, containe
|
||||
if tasks[0].Status == tasktypes.Status_RUNNING {
|
||||
return nil
|
||||
}
|
||||
_ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true})
|
||||
if err := h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}); err != nil {
|
||||
h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
task, err := h.service.StartTask(ctx, containerID, &ctr.StartTaskOptions{
|
||||
@@ -313,7 +319,9 @@ func (h *ContainerdHandler) ensureContainerAndTask(ctx context.Context, containe
|
||||
return err
|
||||
}
|
||||
if err := ctr.SetupNetwork(ctx, task, containerID); err != nil {
|
||||
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true})
|
||||
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}); err != nil {
|
||||
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -324,10 +332,14 @@ func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (s
|
||||
if h.queries != nil {
|
||||
pgBotID, err := db.ParseUUID(botID)
|
||||
if err == nil {
|
||||
row, err := h.queries.GetContainerByBotID(ctx, pgBotID)
|
||||
if err == nil && strings.TrimSpace(row.ContainerID) != "" {
|
||||
row, dbErr := h.queries.GetContainerByBotID(ctx, pgBotID)
|
||||
if dbErr == nil && strings.TrimSpace(row.ContainerID) != "" {
|
||||
return row.ContainerID, nil
|
||||
}
|
||||
if dbErr != nil && !errors.Is(dbErr, pgx.ErrNoRows) {
|
||||
h.logger.Warn("botContainerID: db lookup failed",
|
||||
slog.String("bot_id", botID), slog.Any("error", dbErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: search by containerd label
|
||||
@@ -510,7 +522,9 @@ func (h *ContainerdHandler) StopContainer(c echo.Context) error {
|
||||
}); err != nil && !errdefs.IsNotFound(err) {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
_ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true})
|
||||
if err := h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}); err != nil {
|
||||
h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
if h.queries != nil {
|
||||
if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil {
|
||||
if dbErr := h.queries.UpdateContainerStopped(ctx, pgBotID); dbErr != nil {
|
||||
@@ -646,44 +660,11 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) {
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) && h.policyService != nil {
|
||||
allowGuest, policyErr := h.policyService.AllowGuest(ctx, botID)
|
||||
if policyErr == nil && allowGuest {
|
||||
bot, getErr := h.botService.Get(ctx, botID)
|
||||
if getErr == nil {
|
||||
return bot, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
// SetupBotContainer creates and starts the MCP container for a bot.
|
||||
@@ -701,7 +682,11 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string)
|
||||
if dataRoot == "" {
|
||||
dataRoot = config.DefaultDataRoot
|
||||
}
|
||||
dataRoot, _ = filepath.Abs(dataRoot)
|
||||
if absRoot, absErr := filepath.Abs(dataRoot); absErr != nil {
|
||||
h.logger.Warn("filepath.Abs failed", slog.Any("error", absErr))
|
||||
} else {
|
||||
dataRoot = absRoot
|
||||
}
|
||||
dataMount := strings.TrimSpace(h.cfg.DataMount)
|
||||
if dataMount == "" {
|
||||
dataMount = config.DefaultDataMount
|
||||
@@ -786,7 +771,9 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true})
|
||||
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}); err != nil {
|
||||
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
h.logger.Error("setup bot container: network setup failed",
|
||||
slog.String("bot_id", botID),
|
||||
slog.String("container_id", containerID),
|
||||
@@ -830,15 +817,21 @@ func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID strin
|
||||
|
||||
if task, taskErr := h.service.GetTask(ctx, containerID); taskErr == nil {
|
||||
h.logger.Info("CleanupBotContainer: removing network", slog.String("container_id", containerID))
|
||||
_ = ctr.RemoveNetwork(ctx, task, containerID)
|
||||
if err := ctr.RemoveNetwork(ctx, task, containerID); err != nil {
|
||||
h.logger.Warn("cleanup: remove network failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
h.logger.Info("CleanupBotContainer: stopping task", slog.String("container_id", containerID))
|
||||
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{
|
||||
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{
|
||||
Timeout: 5 * time.Second,
|
||||
Force: true,
|
||||
})
|
||||
}); err != nil {
|
||||
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
h.logger.Info("CleanupBotContainer: deleting task", slog.String("container_id", containerID))
|
||||
_ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true})
|
||||
if err := h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}); err != nil {
|
||||
h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err))
|
||||
}
|
||||
|
||||
h.logger.Info("CleanupBotContainer: deleting container", slog.String("container_id", containerID))
|
||||
if err := h.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
"github.com/memohai/memoh/internal/embeddings"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
@@ -48,7 +47,7 @@ type EmbeddingsResponse struct {
|
||||
type EmbeddingsUsage struct {
|
||||
InputTokens int `json:"input_tokens,omitempty"`
|
||||
ImageTokens int `json:"image_tokens,omitempty"`
|
||||
VideoTokens int `json:"video_tokens,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
func NewEmbeddingsHandler(log *slog.Logger, modelsService *models.Service, queries *sqlc.Queries) *EmbeddingsHandler {
|
||||
@@ -85,12 +84,6 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
|
||||
req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL)
|
||||
req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL)
|
||||
|
||||
userID := ""
|
||||
if c.Get("user") != nil {
|
||||
if value, err := auth.UserIDFromContext(c); err == nil {
|
||||
userID = value
|
||||
}
|
||||
}
|
||||
result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{
|
||||
Type: req.Type,
|
||||
Provider: req.Provider,
|
||||
@@ -101,7 +94,6 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
|
||||
ImageURL: req.Input.ImageURL,
|
||||
VideoURL: req.Input.VideoURL,
|
||||
},
|
||||
ChannelIdentityID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
message := err.Error()
|
||||
@@ -137,7 +129,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
|
||||
Usage: EmbeddingsUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
ImageTokens: result.Usage.ImageTokens,
|
||||
VideoTokens: result.Usage.VideoTokens,
|
||||
Duration: result.Usage.Duration,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -53,22 +53,6 @@ func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerI
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) callMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) (map[string]any, error) {
|
||||
session, err := h.getMCPSession(ctx, containerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session.call(ctx, req)
|
||||
}
|
||||
|
||||
func (h *ContainerdHandler) notifyMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) error {
|
||||
session, err := h.getMCPSession(ctx, containerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return session.notify(ctx, req)
|
||||
}
|
||||
|
||||
type mcpSession struct {
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
)
|
||||
|
||||
// RequireChannelIdentityID extracts and validates the channel identity ID from the request context.
|
||||
func RequireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
}
|
||||
|
||||
// AuthorizeBotAccess validates that the given identity has access to the specified bot.
|
||||
func AuthorizeBotAccess(ctx context.Context, botService *bots.Service, accountService *accounts.Service, channelIdentityID, botID string, policy bots.AccessPolicy) (bots.Bot, error) {
|
||||
if botService == nil || accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, policy)
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -13,12 +12,10 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/channel"
|
||||
"github.com/memohai/memoh/internal/channel/adapters/local"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
)
|
||||
|
||||
// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history.
|
||||
@@ -103,7 +100,9 @@ func (h *LocalChannelHandler) StreamMessages(c echo.Context) error {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, _ = writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data)))
|
||||
if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))); err != nil {
|
||||
return nil // client disconnected
|
||||
}
|
||||
writer.Flush()
|
||||
flusher.Flush()
|
||||
}
|
||||
@@ -185,33 +184,9 @@ func (h *LocalChannelHandler) ensureBotParticipant(ctx context.Context, botID, c
|
||||
}
|
||||
|
||||
func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: true})
|
||||
}
|
||||
|
||||
@@ -11,9 +11,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/mcp"
|
||||
)
|
||||
|
||||
@@ -218,33 +216,9 @@ func (h *MCPHandler) Delete(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return userID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *MCPHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/memory"
|
||||
)
|
||||
|
||||
@@ -168,7 +166,7 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error {
|
||||
for _, scope := range scopes {
|
||||
filters := buildNamespaceFilters(scope.Namespace, scope.ScopeID, payload.Filters)
|
||||
if botID != "" {
|
||||
filters["botId"] = botID
|
||||
filters["bot_id"] = botID
|
||||
}
|
||||
req := memory.SearchRequest{
|
||||
Query: payload.Query,
|
||||
@@ -378,12 +376,5 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
|
||||
}
|
||||
|
||||
func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
@@ -15,12 +15,10 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/channel/identities"
|
||||
"github.com/memohai/memoh/internal/conversation"
|
||||
"github.com/memohai/memoh/internal/conversation/flow"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
messagepkg "github.com/memohai/memoh/internal/message"
|
||||
messageevent "github.com/memohai/memoh/internal/message/event"
|
||||
)
|
||||
@@ -85,7 +83,7 @@ func (h *MessageHandler) SendMessage(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var req flow.ChatRequest
|
||||
var req conversation.ChatRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
@@ -132,7 +130,7 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var req flow.ChatRequest
|
||||
var req conversation.ChatRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
@@ -200,10 +198,12 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error {
|
||||
h.logger.Error("conversation stream failed", slog.Any("error", err))
|
||||
if processingState == "started" {
|
||||
processingState = "failed"
|
||||
_ = writeSSEJSON(writer, flusher, map[string]string{
|
||||
"type": "processing_failed",
|
||||
"error": err.Error(),
|
||||
})
|
||||
if writeErr := writeSSEJSON(writer, flusher, map[string]string{
|
||||
"type": "processing_failed",
|
||||
"error": err.Error(),
|
||||
}); writeErr != nil {
|
||||
h.logger.Warn("write SSE processing_failed event failed", slog.Any("error", writeErr))
|
||||
}
|
||||
}
|
||||
errData := map[string]string{
|
||||
"type": "error",
|
||||
@@ -459,7 +459,7 @@ func (h *MessageHandler) DeleteMessages(c echo.Context) error {
|
||||
// resolveWebChannelIdentity resolves (web, user_id) to a channel identity and sets req.SourceChannelIdentityID.
|
||||
// Web uses user_id as the channel subject id (like Feishu open_id); the resolved ci has display_name and is linked to the user.
|
||||
// Returns the channel_identity_id to use for the rest of the flow, or the original userID if resolution is skipped/fails.
|
||||
func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID string, req *flow.ChatRequest) string {
|
||||
func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID string, req *conversation.ChatRequest) string {
|
||||
if strings.TrimSpace(req.CurrentChannel) != "web" || h.channelIdentitySvc == nil || strings.TrimSpace(userID) == "" {
|
||||
return userID
|
||||
}
|
||||
@@ -476,62 +476,23 @@ func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID s
|
||||
if err != nil {
|
||||
return userID
|
||||
}
|
||||
_ = h.channelIdentitySvc.LinkChannelIdentityToUser(ctx, ci.ID, userID)
|
||||
if err := h.channelIdentitySvc.LinkChannelIdentityToUser(ctx, ci.ID, userID); err != nil {
|
||||
h.logger.Warn("link channel identity to user failed", slog.Any("error", err))
|
||||
}
|
||||
req.SourceChannelIdentityID = ci.ID
|
||||
return ci.ID
|
||||
}
|
||||
|
||||
func (h *MessageHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *MessageHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: true})
|
||||
}
|
||||
|
||||
func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot management access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
func (h *MessageHandler) requireParticipant(ctx context.Context, conversationID, channelIdentityID string) error {
|
||||
|
||||
@@ -4,26 +4,21 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/models"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
)
|
||||
|
||||
type ModelsHandler struct {
|
||||
service *models.Service
|
||||
settingsService *settings.Service
|
||||
logger *slog.Logger
|
||||
service *models.Service
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewModelsHandler(log *slog.Logger, service *models.Service, settingsService *settings.Service) *ModelsHandler {
|
||||
func NewModelsHandler(log *slog.Logger, service *models.Service) *ModelsHandler {
|
||||
return &ModelsHandler{
|
||||
service: service,
|
||||
settingsService: settingsService,
|
||||
logger: log.With(slog.String("handler", "models")),
|
||||
service: service,
|
||||
logger: log.With(slog.String("handler", "models")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +28,6 @@ func (h *ModelsHandler) Register(e *echo.Echo) {
|
||||
group.GET("", h.List)
|
||||
group.GET("/:id", h.GetByID)
|
||||
group.GET("/model/:modelId", h.GetByModelID)
|
||||
group.POST("/enable", h.Enable)
|
||||
group.PUT("/:id", h.UpdateByID)
|
||||
group.PUT("/model/:modelId", h.UpdateByModelID)
|
||||
group.DELETE("/:id", h.DeleteByID)
|
||||
@@ -145,67 +139,6 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
type EnableModelRequest struct {
|
||||
As string `json:"as"`
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
|
||||
// Enable godoc
|
||||
// @Summary Enable model for chat/memory/embedding
|
||||
// @Description Update the current user's settings to use the selected model
|
||||
// @Tags models
|
||||
// @Param payload body handlers.EnableModelRequest true "Enable model payload"
|
||||
// @Success 200 {object} settings.Settings
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
// @Failure 404 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /models/enable [post]
|
||||
func (h *ModelsHandler) Enable(c echo.Context) error {
|
||||
if h.settingsService == nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured")
|
||||
}
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req EnableModelRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
req.As = strings.ToLower(strings.TrimSpace(req.As))
|
||||
req.ModelID = strings.TrimSpace(req.ModelID)
|
||||
if req.As == "" || req.ModelID == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "as and model_id are required")
|
||||
}
|
||||
if req.As != "chat" && req.As != "memory" && req.As != "embedding" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "as must be one of chat, memory, embedding")
|
||||
}
|
||||
model, err := h.service.GetByModelID(c.Request().Context(), req.ModelID)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusNotFound, err.Error())
|
||||
}
|
||||
if req.As == "embedding" && model.Type != models.ModelTypeEmbedding {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "model is not an embedding model")
|
||||
}
|
||||
if (req.As == "chat" || req.As == "memory") && model.Type != models.ModelTypeChat {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "model is not a chat model")
|
||||
}
|
||||
upsert := settings.UpsertRequest{}
|
||||
switch req.As {
|
||||
case "chat":
|
||||
upsert.ChatModelID = req.ModelID
|
||||
case "memory":
|
||||
upsert.MemoryModelID = req.ModelID
|
||||
case "embedding":
|
||||
upsert.EmbeddingModelID = req.ModelID
|
||||
}
|
||||
resp, err := h.settingsService.Upsert(c.Request().Context(), userID, upsert)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// UpdateByID godoc
|
||||
// @Summary Update model by internal ID
|
||||
// @Description Update a model configuration by its internal UUID
|
||||
|
||||
@@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -10,9 +9,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/preauth"
|
||||
)
|
||||
|
||||
@@ -67,33 +64,9 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return userID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -10,9 +9,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/schedule"
|
||||
)
|
||||
|
||||
@@ -219,33 +216,9 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) {
|
||||
userID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(userID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return userID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/settings"
|
||||
)
|
||||
|
||||
@@ -130,33 +128,9 @@ func (h *SettingsHandler) Delete(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -10,9 +9,7 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/memohai/memoh/internal/accounts"
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/bots"
|
||||
"github.com/memohai/memoh/internal/identity"
|
||||
"github.com/memohai/memoh/internal/subagent"
|
||||
)
|
||||
|
||||
@@ -433,33 +430,9 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
if h.botService == nil || h.accountService == nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
|
||||
}
|
||||
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ type UsersHandler struct {
|
||||
}
|
||||
|
||||
type listMyIdentitiesResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Items []identities.ChannelIdentity `json:"items"`
|
||||
}
|
||||
|
||||
@@ -943,30 +943,9 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error {
|
||||
}
|
||||
|
||||
func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
|
||||
isAdmin, err := h.service.IsAdmin(ctx, channelIdentityID)
|
||||
if err != nil {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
|
||||
if err != nil {
|
||||
if errors.Is(err, bots.ErrBotNotFound) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
|
||||
}
|
||||
if errors.Is(err, bots.ErrBotAccessDenied) {
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
|
||||
}
|
||||
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return bot, nil
|
||||
return AuthorizeBotAccess(ctx, h.botService, h.service, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
|
||||
}
|
||||
|
||||
func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
|
||||
channelIdentityID, err := auth.UserIDFromContext(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
return channelIdentityID, nil
|
||||
return RequireChannelIdentityID(c)
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type Connection struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Config map[string]any `json:"config"`
|
||||
Active bool `json:"active"`
|
||||
Active bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -29,7 +29,7 @@ type UpsertRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Config map[string]any `json:"config"`
|
||||
Active *bool `json:"active,omitempty"`
|
||||
Active *bool `json:"is_active,omitempty"`
|
||||
}
|
||||
|
||||
// ListResponse wraps MCP connection list responses.
|
||||
|
||||
+9
-11
@@ -175,7 +175,9 @@ func (m *Manager) Start(ctx context.Context, botID string) error {
|
||||
return err
|
||||
}
|
||||
if err := ctr.SetupNetwork(ctx, task, m.containerID(botID)); err != nil {
|
||||
_ = m.service.StopTask(ctx, m.containerID(botID), &ctr.StopTaskOptions{Force: true})
|
||||
if stopErr := m.service.StopTask(ctx, m.containerID(botID), &ctr.StopTaskOptions{Force: true}); stopErr != nil {
|
||||
m.logger.Warn("cleanup: stop task failed", slog.String("container_id", m.containerID(botID)), slog.Any("error", stopErr))
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -197,9 +199,13 @@ func (m *Manager) Delete(ctx context.Context, botID string) error {
|
||||
}
|
||||
|
||||
if task, taskErr := m.service.GetTask(ctx, m.containerID(botID)); taskErr == nil {
|
||||
_ = ctr.RemoveNetwork(ctx, task, m.containerID(botID))
|
||||
if err := ctr.RemoveNetwork(ctx, task, m.containerID(botID)); err != nil {
|
||||
m.logger.Warn("cleanup: remove network failed", slog.String("container_id", m.containerID(botID)), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
if err := m.service.DeleteTask(ctx, m.containerID(botID), &ctr.DeleteTaskOptions{Force: true}); err != nil {
|
||||
m.logger.Warn("cleanup: delete task failed", slog.String("container_id", m.containerID(botID)), slog.Any("error", err))
|
||||
}
|
||||
_ = m.service.DeleteTask(ctx, m.containerID(botID), &ctr.DeleteTaskOptions{Force: true})
|
||||
return m.service.DeleteContainer(ctx, m.containerID(botID), &ctr.DeleteContainerOptions{
|
||||
CleanupSnapshot: true,
|
||||
})
|
||||
@@ -347,14 +353,6 @@ func (m *Manager) execWithCaptureContainerd(ctx context.Context, req ExecRequest
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sshShellQuote wraps a string in single quotes for safe SSH transport.
|
||||
func sshShellQuote(s string) string {
|
||||
if s == "" {
|
||||
return "''"
|
||||
}
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
// DataDir returns the host data directory for a bot.
|
||||
func (m *Manager) DataDir(botID string) (string, error) {
|
||||
if err := validateBotID(botID); err != nil {
|
||||
|
||||
@@ -134,7 +134,7 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
|
||||
Filters: map[string]any{
|
||||
"namespace": sharedMemoryNamespace,
|
||||
"scopeId": botID,
|
||||
"botId": botID,
|
||||
"bot_id": botID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -23,15 +23,15 @@ func (f *fakeSearcher) Search(ctx context.Context, req memory.SearchRequest) (me
|
||||
}
|
||||
|
||||
type fakeChatAccessor struct {
|
||||
chat conversation.Chat
|
||||
chat conversation.Conversation
|
||||
getErr error
|
||||
participant bool
|
||||
participantErr error
|
||||
}
|
||||
|
||||
func (f *fakeChatAccessor) Get(ctx context.Context, conversationID string) (conversation.Chat, error) {
|
||||
func (f *fakeChatAccessor) Get(ctx context.Context, conversationID string) (conversation.Conversation, error) {
|
||||
if f.getErr != nil {
|
||||
return conversation.Chat{}, f.getErr
|
||||
return conversation.Conversation{}, f.getErr
|
||||
}
|
||||
return f.chat, nil
|
||||
}
|
||||
@@ -43,8 +43,8 @@ func (f *fakeChatAccessor) IsParticipant(ctx context.Context, conversationID, ch
|
||||
return f.participant, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatAccessor) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (conversation.ChatReadAccess, error) {
|
||||
return conversation.ChatReadAccess{}, nil
|
||||
func (f *fakeChatAccessor) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (conversation.ConversationReadAccess, error) {
|
||||
return conversation.ConversationReadAccess{}, nil
|
||||
}
|
||||
|
||||
type fakeAdminChecker struct {
|
||||
@@ -180,7 +180,7 @@ func TestExecutor_CallTool_ChatNotFound(t *testing.T) {
|
||||
|
||||
func TestExecutor_CallTool_BotMismatch(t *testing.T) {
|
||||
accessor := &fakeChatAccessor{
|
||||
chat: conversation.Chat{BotID: "other-bot", ID: "c1"},
|
||||
chat: conversation.Conversation{BotID: "other-bot", ID: "c1"},
|
||||
}
|
||||
searcher := &fakeSearcher{}
|
||||
exec := NewExecutor(nil, searcher, accessor, nil)
|
||||
@@ -196,7 +196,7 @@ func TestExecutor_CallTool_BotMismatch(t *testing.T) {
|
||||
|
||||
func TestExecutor_CallTool_NotParticipant(t *testing.T) {
|
||||
accessor := &fakeChatAccessor{
|
||||
chat: conversation.Chat{BotID: "bot1", ID: "c1"},
|
||||
chat: conversation.Conversation{BotID: "bot1", ID: "c1"},
|
||||
participant: false,
|
||||
}
|
||||
searcher := &fakeSearcher{}
|
||||
@@ -216,7 +216,7 @@ func TestExecutor_CallTool_AdminBypass(t *testing.T) {
|
||||
resp: memory.SearchResponse{Results: []memory.MemoryItem{{ID: "id1", Memory: "m1", Score: 0.8}}},
|
||||
}
|
||||
accessor := &fakeChatAccessor{
|
||||
chat: conversation.Chat{BotID: "bot1", ID: "c1"},
|
||||
chat: conversation.Conversation{BotID: "bot1", ID: "c1"},
|
||||
participant: false,
|
||||
}
|
||||
admin := &fakeAdminChecker{admin: true}
|
||||
|
||||
@@ -91,7 +91,10 @@ func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext,
|
||||
|
||||
route, ok := s.getRoute(botID, toolName)
|
||||
if !ok {
|
||||
_, _ = s.ListTools(ctx, session)
|
||||
// Refresh route cache; result intentionally discarded.
|
||||
if _, err := s.ListTools(ctx, session); err != nil {
|
||||
s.logger.Warn("federation: refresh tools cache failed", slog.Any("error", err))
|
||||
}
|
||||
route, ok = s.getRoute(botID, toolName)
|
||||
if !ok {
|
||||
return nil, mcpgw.ErrToolNotFound
|
||||
|
||||
@@ -2,6 +2,7 @@ package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@@ -106,7 +107,7 @@ func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionCo
|
||||
}
|
||||
result, err := executor.CallTool(ctx, session, toolName, arguments)
|
||||
if err != nil {
|
||||
if err == ErrToolNotFound {
|
||||
if errors.Is(err, ErrToolNotFound) {
|
||||
return BuildToolErrorResult("tool not found: " + toolName), nil
|
||||
}
|
||||
return BuildToolErrorResult(err.Error()), nil
|
||||
|
||||
@@ -349,7 +349,9 @@ func parseJSONMap(data []byte) map[string]any {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(data, &m)
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err))
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
|
||||
@@ -314,7 +314,7 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
|
||||
|
||||
func convertToGetResponse(dbModel sqlc.Model) GetResponse {
|
||||
resp := GetResponse{
|
||||
ModelId: dbModel.ModelID,
|
||||
ModelID: dbModel.ModelID,
|
||||
Model: Model{
|
||||
ModelID: dbModel.ModelID,
|
||||
IsMultimodal: dbModel.IsMultimodal,
|
||||
|
||||
@@ -83,7 +83,7 @@ func ExampleService_UpdateByModelID() {
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
// fmt.Printf("Updated model: %s\n", resp.ModelId)
|
||||
// fmt.Printf("Updated model: %s\n", resp.ModelID)
|
||||
}
|
||||
|
||||
func ExampleService_DeleteByModelID() {
|
||||
@@ -202,74 +202,3 @@ func TestModelTypes(t *testing.T) {
|
||||
assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope)
|
||||
})
|
||||
}
|
||||
|
||||
// Integration test example (requires actual database)
|
||||
// func TestService_Integration(t *testing.T) {
|
||||
// if testing.Short() {
|
||||
// t.Skip("Skipping integration test")
|
||||
// }
|
||||
//
|
||||
// ctx := context.Background()
|
||||
//
|
||||
// // Setup database connection
|
||||
// pool, err := db.Open(ctx, config.PostgresConfig{
|
||||
// Host: "localhost",
|
||||
// Port: 5432,
|
||||
// User: "test",
|
||||
// Password: "test",
|
||||
// Database: "test_db",
|
||||
// SSLMode: "disable",
|
||||
// })
|
||||
// require.NoError(t, err)
|
||||
// defer pool.Close()
|
||||
//
|
||||
// queries := sqlc.New(pool)
|
||||
// service := models.NewService(queries)
|
||||
//
|
||||
// // Test Create
|
||||
// createReq := models.AddRequest{
|
||||
// ModelID: "test-gpt-4",
|
||||
// Name: "Test GPT-4",
|
||||
// BaseURL: "https://api.openai.com/v1",
|
||||
// APIKey: "sk-test",
|
||||
// ClientType: models.ClientTypeOpenAI,
|
||||
// Type: models.ModelTypeChat,
|
||||
// }
|
||||
// createResp, err := service.Create(ctx, createReq)
|
||||
// require.NoError(t, err)
|
||||
// assert.NotEmpty(t, createResp.ID)
|
||||
// assert.Equal(t, "test-gpt-4", createResp.ModelID)
|
||||
//
|
||||
// // Test GetByModelID
|
||||
// getResp, err := service.GetByModelID(ctx, "test-gpt-4")
|
||||
// require.NoError(t, err)
|
||||
// assert.Equal(t, "test-gpt-4", getResp.ModelID)
|
||||
// assert.Equal(t, "Test GPT-4", getResp.Name)
|
||||
//
|
||||
// // Test List
|
||||
// models, err := service.List(ctx)
|
||||
// require.NoError(t, err)
|
||||
// assert.NotEmpty(t, models)
|
||||
//
|
||||
// // Test Update
|
||||
// updateReq := models.UpdateRequest{
|
||||
// ModelID: "test-gpt-4",
|
||||
// Name: "Updated GPT-4",
|
||||
// BaseURL: "https://api.openai.com/v1",
|
||||
// APIKey: "sk-test-updated",
|
||||
// ClientType: models.ClientTypeOpenAI,
|
||||
// Type: models.ModelTypeChat,
|
||||
// }
|
||||
// updateResp, err := service.UpdateByModelID(ctx, "test-gpt-4", updateReq)
|
||||
// require.NoError(t, err)
|
||||
// assert.Equal(t, "Updated GPT-4", updateResp.Name)
|
||||
//
|
||||
// // Test Count
|
||||
// count, err := service.Count(ctx)
|
||||
// require.NoError(t, err)
|
||||
// assert.Greater(t, count, int64(0))
|
||||
//
|
||||
// // Test Delete
|
||||
// err = service.DeleteByModelID(ctx, "test-gpt-4")
|
||||
// require.NoError(t, err)
|
||||
// }
|
||||
|
||||
@@ -75,7 +75,7 @@ type GetRequest struct {
|
||||
}
|
||||
|
||||
type GetResponse struct {
|
||||
ModelId string `json:"model_id"`
|
||||
ModelID string `json:"model_id"`
|
||||
Model
|
||||
}
|
||||
|
||||
|
||||
@@ -90,13 +90,13 @@ func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) {
|
||||
|
||||
func normalizeKey(row sqlc.BotPreauthKey) Key {
|
||||
return Key{
|
||||
ID: row.ID.String(),
|
||||
BotID: row.BotID.String(),
|
||||
Token: strings.TrimSpace(row.Token),
|
||||
IssuedByChannelIdentityID: row.IssuedByUserID.String(),
|
||||
ExpiresAt: timeFromPg(row.ExpiresAt),
|
||||
UsedAt: timeFromPg(row.UsedAt),
|
||||
CreatedAt: timeFromPg(row.CreatedAt),
|
||||
ID: row.ID.String(),
|
||||
BotID: row.BotID.String(),
|
||||
Token: strings.TrimSpace(row.Token),
|
||||
IssuedByUserID: row.IssuedByUserID.String(),
|
||||
ExpiresAt: timeFromPg(row.ExpiresAt),
|
||||
UsedAt: timeFromPg(row.UsedAt),
|
||||
CreatedAt: timeFromPg(row.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ import "time"
|
||||
|
||||
// Key represents a bot pre-authorization key.
|
||||
type Key struct {
|
||||
ID string
|
||||
BotID string
|
||||
Token string
|
||||
IssuedByChannelIdentityID string
|
||||
ExpiresAt time.Time
|
||||
UsedAt time.Time
|
||||
CreatedAt time.Time
|
||||
ID string
|
||||
BotID string
|
||||
Token string
|
||||
IssuedByUserID string
|
||||
ExpiresAt time.Time
|
||||
UsedAt time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
@@ -211,7 +211,9 @@ func (s *Service) CountByClientType(ctx context.Context, clientType ClientType)
|
||||
func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
|
||||
var metadata map[string]any
|
||||
if len(provider.Metadata) > 0 {
|
||||
_ = json.Unmarshal(provider.Metadata, &metadata)
|
||||
if err := json.Unmarshal(provider.Metadata, &metadata); err != nil {
|
||||
slog.Warn("provider metadata unmarshal failed", slog.String("id", provider.ID.String()), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Mask API key (show only first 8 characters)
|
||||
|
||||
@@ -187,7 +187,9 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sch
|
||||
if err != nil {
|
||||
return Schedule{}, err
|
||||
}
|
||||
s.rescheduleJob(updated)
|
||||
if err := s.rescheduleJob(updated); err != nil {
|
||||
return Schedule{}, fmt.Errorf("reschedule job: %w", err)
|
||||
}
|
||||
return toSchedule(updated), nil
|
||||
}
|
||||
|
||||
@@ -287,7 +289,9 @@ func (s *Service) scheduleJob(schedule sqlc.Schedule) error {
|
||||
return fmt.Errorf("schedule id missing")
|
||||
}
|
||||
job := func() {
|
||||
_ = s.runSchedule(context.Background(), toSchedule(schedule))
|
||||
if err := s.runSchedule(context.Background(), toSchedule(schedule)); err != nil {
|
||||
s.logger.Error("scheduled job failed", slog.String("schedule_id", schedule.ID.String()), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
entryID, err := s.cron.AddFunc(schedule.Pattern, job)
|
||||
if err != nil {
|
||||
@@ -299,15 +303,16 @@ func (s *Service) scheduleJob(schedule sqlc.Schedule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) rescheduleJob(schedule sqlc.Schedule) {
|
||||
func (s *Service) rescheduleJob(schedule sqlc.Schedule) error {
|
||||
id := schedule.ID.String()
|
||||
if id == "" {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
s.removeJob(id)
|
||||
if schedule.Enabled {
|
||||
_ = s.scheduleJob(schedule)
|
||||
return s.scheduleJob(schedule)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) removeJob(id string) {
|
||||
|
||||
@@ -11,7 +11,6 @@ type TriggerPayload struct {
|
||||
MaxCalls *int
|
||||
Command string
|
||||
OwnerUserID string
|
||||
ChatID string
|
||||
}
|
||||
|
||||
// Triggerer triggers schedule execution for chat-related jobs.
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
@@ -28,83 +27,6 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns user-level settings.
|
||||
func (s *Service) Get(ctx context.Context, userID string) (Settings, error) {
|
||||
pgID, err := db.ParseUUID(userID)
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
row, err := s.queries.GetSettingsByUserID(ctx, pgID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return Settings{
|
||||
ChatModelID: "",
|
||||
MemoryModelID: "",
|
||||
EmbeddingModelID: "",
|
||||
MaxContextLoadTime: DefaultMaxContextLoadTime,
|
||||
Language: DefaultLanguage,
|
||||
}, nil
|
||||
}
|
||||
return Settings{}, err
|
||||
}
|
||||
return normalizeUserSetting(row), nil
|
||||
}
|
||||
|
||||
// Upsert creates or updates user-level settings.
|
||||
func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) (Settings, error) {
|
||||
if s.queries == nil {
|
||||
return Settings{}, fmt.Errorf("settings queries not configured")
|
||||
}
|
||||
pgID, err := db.ParseUUID(userID)
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
|
||||
current := Settings{
|
||||
ChatModelID: "",
|
||||
MemoryModelID: "",
|
||||
EmbeddingModelID: "",
|
||||
MaxContextLoadTime: DefaultMaxContextLoadTime,
|
||||
Language: DefaultLanguage,
|
||||
}
|
||||
existing, err := s.queries.GetSettingsByUserID(ctx, pgID)
|
||||
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
return Settings{}, err
|
||||
}
|
||||
if err == nil {
|
||||
current = normalizeUserSetting(existing)
|
||||
}
|
||||
|
||||
if value := strings.TrimSpace(req.ChatModelID); value != "" {
|
||||
current.ChatModelID = value
|
||||
}
|
||||
if value := strings.TrimSpace(req.MemoryModelID); value != "" {
|
||||
current.MemoryModelID = value
|
||||
}
|
||||
if value := strings.TrimSpace(req.EmbeddingModelID); value != "" {
|
||||
current.EmbeddingModelID = value
|
||||
}
|
||||
if req.MaxContextLoadTime != nil && *req.MaxContextLoadTime > 0 {
|
||||
current.MaxContextLoadTime = *req.MaxContextLoadTime
|
||||
}
|
||||
if strings.TrimSpace(req.Language) != "" {
|
||||
current.Language = strings.TrimSpace(req.Language)
|
||||
}
|
||||
|
||||
_, err = s.queries.UpsertUserSettings(ctx, sqlc.UpsertUserSettingsParams{
|
||||
ID: pgID,
|
||||
ChatModelID: pgtype.Text{String: current.ChatModelID, Valid: current.ChatModelID != ""},
|
||||
MemoryModelID: pgtype.Text{String: current.MemoryModelID, Valid: current.MemoryModelID != ""},
|
||||
EmbeddingModelID: pgtype.Text{String: current.EmbeddingModelID, Valid: current.EmbeddingModelID != ""},
|
||||
MaxContextLoadTime: int32(current.MaxContextLoadTime),
|
||||
Language: current.Language,
|
||||
})
|
||||
if err != nil {
|
||||
return Settings{}, err
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) {
|
||||
pgID, err := db.ParseUUID(botID)
|
||||
if err != nil {
|
||||
@@ -198,23 +120,6 @@ func (s *Service) Delete(ctx context.Context, botID string) error {
|
||||
return s.queries.DeleteSettingsByBotID(ctx, pgID)
|
||||
}
|
||||
|
||||
func normalizeUserSetting(row sqlc.GetSettingsByUserIDRow) Settings {
|
||||
settings := Settings{
|
||||
ChatModelID: strings.TrimSpace(row.ChatModelID.String),
|
||||
MemoryModelID: strings.TrimSpace(row.MemoryModelID.String),
|
||||
EmbeddingModelID: strings.TrimSpace(row.EmbeddingModelID.String),
|
||||
MaxContextLoadTime: int(row.MaxContextLoadTime),
|
||||
Language: strings.TrimSpace(row.Language),
|
||||
}
|
||||
if settings.MaxContextLoadTime <= 0 {
|
||||
settings.MaxContextLoadTime = DefaultMaxContextLoadTime
|
||||
}
|
||||
if settings.Language == "" {
|
||||
settings.Language = DefaultLanguage
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
func normalizeBotSetting(maxContextLoadTime int32, language string, allowGuest bool) Settings {
|
||||
settings := Settings{
|
||||
MaxContextLoadTime: int(maxContextLoadTime),
|
||||
@@ -277,4 +182,3 @@ func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.
|
||||
}
|
||||
return row.ID, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -62,13 +62,6 @@ type ScheduleListResponse = {
|
||||
items: Schedule[]
|
||||
}
|
||||
|
||||
type Settings = {
|
||||
chat_model_id: string
|
||||
memory_model_id: string
|
||||
embedding_model_id: string
|
||||
max_context_load_time: number
|
||||
language: string
|
||||
}
|
||||
|
||||
type Bot = {
|
||||
id: string
|
||||
@@ -280,18 +273,6 @@ configCmd.action(async () => {
|
||||
const config = readConfig()
|
||||
console.log(`host = "${config.host}"`)
|
||||
console.log(`port = ${config.port}`)
|
||||
const token = readToken()
|
||||
if (!token?.access_token) return
|
||||
try {
|
||||
const settings = await apiRequest<Settings>('/settings', {}, token)
|
||||
console.log(`chat_model_id = "${settings.chat_model_id || ''}"`)
|
||||
console.log(`memory_model_id = "${settings.memory_model_id || ''}"`)
|
||||
console.log(`embedding_model_id = "${settings.embedding_model_id || ''}"`)
|
||||
console.log(`max_context_load_time = ${settings.max_context_load_time}`)
|
||||
console.log(`language = "${settings.language}"`)
|
||||
} catch (err: unknown) {
|
||||
console.log(chalk.yellow(`Unable to load settings: ${getErrorMessage(err)}`))
|
||||
}
|
||||
})
|
||||
|
||||
configCmd
|
||||
@@ -299,33 +280,12 @@ configCmd
|
||||
.description('Update config')
|
||||
.option('--host <host>')
|
||||
.option('--port <port>')
|
||||
.option('--chat_model_id <model_id>')
|
||||
.option('--memory_model_id <model_id>')
|
||||
.option('--embedding_model_id <model_id>')
|
||||
.option('--max_context_load_time <minutes>')
|
||||
.option('--language <language>')
|
||||
.action(async (opts) => {
|
||||
const current = readConfig()
|
||||
let host = opts.host
|
||||
let port = opts.port ? Number.parseInt(opts.port, 10) : undefined
|
||||
let maxContextLoadTime: number | undefined
|
||||
if (opts.max_context_load_time !== undefined) {
|
||||
const parsed = Number.parseInt(opts.max_context_load_time, 10)
|
||||
if (Number.isNaN(parsed) || parsed <= 0) {
|
||||
console.log(chalk.red('max_context_load_time must be a positive integer.'))
|
||||
process.exit(1)
|
||||
}
|
||||
maxContextLoadTime = parsed
|
||||
}
|
||||
let language = opts.language
|
||||
const hasSettingsInput = opts.max_context_load_time !== undefined
|
||||
|| opts.language !== undefined
|
||||
|| opts.chat_model_id !== undefined
|
||||
|| opts.memory_model_id !== undefined
|
||||
|| opts.embedding_model_id !== undefined
|
||||
const hasConfigInput = Boolean(host || port)
|
||||
|
||||
if (!hasConfigInput && !hasSettingsInput) {
|
||||
if (!host && !port) {
|
||||
const answers = await inquirer.prompt([
|
||||
{ type: 'input', name: 'host', message: 'Host:', default: current.host },
|
||||
{ type: 'input', name: 'port', message: 'Port:', default: current.port },
|
||||
@@ -337,31 +297,8 @@ configCmd
|
||||
if (host) current.host = host
|
||||
if (port && !Number.isNaN(port)) current.port = port
|
||||
|
||||
if (host || (port && !Number.isNaN(port))) {
|
||||
writeConfig(current)
|
||||
console.log(chalk.green('Config updated'))
|
||||
}
|
||||
|
||||
if (hasSettingsInput) {
|
||||
if (language) {
|
||||
language = String(language).trim()
|
||||
}
|
||||
const payload: Partial<Settings> = {}
|
||||
if (opts.chat_model_id) payload.chat_model_id = String(opts.chat_model_id).trim()
|
||||
if (opts.memory_model_id) payload.memory_model_id = String(opts.memory_model_id).trim()
|
||||
if (opts.embedding_model_id) payload.embedding_model_id = String(opts.embedding_model_id).trim()
|
||||
if (maxContextLoadTime !== undefined) payload.max_context_load_time = maxContextLoadTime
|
||||
if (language) payload.language = language
|
||||
const token = ensureAuth()
|
||||
const spinner = ora('Updating settings...').start()
|
||||
try {
|
||||
await apiRequest('/settings', { method: 'PUT', body: JSON.stringify(payload) }, token)
|
||||
spinner.succeed('Settings updated')
|
||||
} catch (err: unknown) {
|
||||
spinner.fail(getErrorMessage(err) || 'Failed to update settings')
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
writeConfig(current)
|
||||
console.log(chalk.green('Config updated'))
|
||||
})
|
||||
|
||||
const provider = program.command('provider').description('Provider management')
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -374,11 +374,6 @@ export type HandlersEmbeddingsUsage = {
|
||||
video_tokens?: number;
|
||||
};
|
||||
|
||||
export type HandlersEnableModelRequest = {
|
||||
as?: string;
|
||||
model_id?: string;
|
||||
};
|
||||
|
||||
export type HandlersErrorResponse = {
|
||||
message?: string;
|
||||
};
|
||||
@@ -2923,42 +2918,6 @@ export type GetModelsCountResponses = {
|
||||
|
||||
export type GetModelsCountResponse = GetModelsCountResponses[keyof GetModelsCountResponses];
|
||||
|
||||
export type PostModelsEnableData = {
|
||||
/**
|
||||
* Enable model payload
|
||||
*/
|
||||
body: HandlersEnableModelRequest;
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/models/enable';
|
||||
};
|
||||
|
||||
export type PostModelsEnableErrors = {
|
||||
/**
|
||||
* Bad Request
|
||||
*/
|
||||
400: HandlersErrorResponse;
|
||||
/**
|
||||
* Not Found
|
||||
*/
|
||||
404: HandlersErrorResponse;
|
||||
/**
|
||||
* Internal Server Error
|
||||
*/
|
||||
500: HandlersErrorResponse;
|
||||
};
|
||||
|
||||
export type PostModelsEnableError = PostModelsEnableErrors[keyof PostModelsEnableErrors];
|
||||
|
||||
export type PostModelsEnableResponses = {
|
||||
/**
|
||||
* OK
|
||||
*/
|
||||
200: SettingsSettings;
|
||||
};
|
||||
|
||||
export type PostModelsEnableResponse = PostModelsEnableResponses[keyof PostModelsEnableResponses];
|
||||
|
||||
export type DeleteModelsModelByModelIdData = {
|
||||
body?: never;
|
||||
path: {
|
||||
|
||||
@@ -2640,52 +2640,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/enable": {
|
||||
"post": {
|
||||
"description": "Update the current user's settings to use the selected model",
|
||||
"tags": [
|
||||
"models"
|
||||
],
|
||||
"summary": "Enable model for chat/memory/embedding",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Enable model payload",
|
||||
"name": "payload",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.EnableModelRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/settings.Settings"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/model/{modelId}": {
|
||||
"get": {
|
||||
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
|
||||
@@ -4754,17 +4708,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"handlers.EnableModelRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"as": {
|
||||
"type": "string"
|
||||
},
|
||||
"model_id": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"handlers.ErrorResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -2631,52 +2631,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/enable": {
|
||||
"post": {
|
||||
"description": "Update the current user's settings to use the selected model",
|
||||
"tags": [
|
||||
"models"
|
||||
],
|
||||
"summary": "Enable model for chat/memory/embedding",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Enable model payload",
|
||||
"name": "payload",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.EnableModelRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/settings.Settings"
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"description": "Bad Request",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Not Found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/handlers.ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/models/model/{modelId}": {
|
||||
"get": {
|
||||
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
|
||||
@@ -4745,17 +4699,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"handlers.EnableModelRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"as": {
|
||||
"type": "string"
|
||||
},
|
||||
"model_id": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"handlers.ErrorResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -629,13 +629,6 @@ definitions:
|
||||
video_tokens:
|
||||
type: integer
|
||||
type: object
|
||||
handlers.EnableModelRequest:
|
||||
properties:
|
||||
as:
|
||||
type: string
|
||||
model_id:
|
||||
type: string
|
||||
type: object
|
||||
handlers.ErrorResponse:
|
||||
properties:
|
||||
message:
|
||||
@@ -3023,36 +3016,6 @@ paths:
|
||||
summary: Get model count
|
||||
tags:
|
||||
- models
|
||||
/models/enable:
|
||||
post:
|
||||
description: Update the current user's settings to use the selected model
|
||||
parameters:
|
||||
- description: Enable model payload
|
||||
in: body
|
||||
name: payload
|
||||
required: true
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.EnableModelRequest'
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
$ref: '#/definitions/settings.Settings'
|
||||
"400":
|
||||
description: Bad Request
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
"404":
|
||||
description: Not Found
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/handlers.ErrorResponse'
|
||||
summary: Enable model for chat/memory/embedding
|
||||
tags:
|
||||
- models
|
||||
/models/model/{modelId}:
|
||||
delete:
|
||||
description: Delete a model configuration by its model_id field (e.g., gpt-4)
|
||||
|
||||
Reference in New Issue
Block a user