refactor(core): codebase quality cleanup

- Remove user-level model settings (chat_model_id, memory_model_id,
  embedding_model_id, max_context_load_time, language) from users table
- Merge migration 0002 into 0001, remove compatibility migrations
- Delete dead conversation/resolver.go (1177 lines, only flow/resolver.go used)
- Remove type aliases (Chat=Conversation, types_alias.go)
- Fix SQL: remove AND false stub, fix UpdateChatTitle model_id,
  reset model IDs in DeleteSettings, add preauth expiry filter,
  add ListMessages limit, remove 10 dead queries
- Extract shared handler helpers (RequireChannelIdentityID, AuthorizeBotAccess)
- Rename internal/router to internal/channel/inbound
- Fix identity confusion: remove UserID->ChannelIdentityID fallbacks
- Fix all _ = var patterns with proper error logging
- Fix error propagation: storeMessages, rescheduleJob, botContainerID
- Fix naming: ModelId->ModelID, active->is_active, Duration semantic fix
- Remove dead code: mcpService, ReplyTarget, callMCPServer, sshShellQuote,
  buildSessionMetadata, ChatRequest.Language, TriggerPayload.ChatID
- Fix code quality: errors.Is(), remove goto, CreateHuman deprecated
- Remove Enable model endpoint and user-level settings CLI commands
- Regenerate sqlc, swagger, SDK
This commit is contained in:
BBQ
2026-02-12 23:43:29 +08:00
parent 57dd75ff52
commit 85251a2905
87 changed files with 509 additions and 2994 deletions
+6 -6
View File
@@ -50,7 +50,7 @@ import (
"github.com/memohai/memoh/internal/policy"
"github.com/memohai/memoh/internal/preauth"
"github.com/memohai/memoh/internal/providers"
"github.com/memohai/memoh/internal/router"
"github.com/memohai/memoh/internal/channel/inbound"
"github.com/memohai/memoh/internal/schedule"
"github.com/memohai/memoh/internal/server"
"github.com/memohai/memoh/internal/settings"
@@ -305,8 +305,8 @@ func provideScheduleTriggerer(resolver *flow.Resolver) schedule.Triggerer {
// conversation flow
// ---------------------------------------------------------------------------
func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, mcpConnService *mcp.ConnectionService, containerdHandler *handlers.ContainerdHandler) *flow.Resolver {
resolver := flow.NewResolver(log, modelsService, queries, memoryService, chatService, msgService, settingsService, mcpConnService, cfg.AgentGateway.BaseURL(), 120*time.Second)
func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, memoryService *memory.Service, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, containerdHandler *handlers.ContainerdHandler) *flow.Resolver {
resolver := flow.NewResolver(log, modelsService, queries, memoryService, chatService, msgService, settingsService, cfg.AgentGateway.BaseURL(), 120*time.Second)
resolver.SetSkillLoader(&skillLoaderAdapter{handler: containerdHandler})
return resolver
}
@@ -324,11 +324,11 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub) *channel.Regi
return registry
}
func provideChannelRouter(log *slog.Logger, registry *channel.Registry, routeService *route.DBService, msgService *message.DBService, resolver *flow.Resolver, identityService *identities.Service, botService *bots.Service, policyService *policy.Service, preauthService *preauth.Service, bindService *bind.Service, rc *boot.RuntimeConfig) *router.ChannelInboundProcessor {
return router.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, botService, policyService, preauthService, bindService, rc.JwtSecret, 5*time.Minute)
func provideChannelRouter(log *slog.Logger, registry *channel.Registry, routeService *route.DBService, msgService *message.DBService, resolver *flow.Resolver, identityService *identities.Service, botService *bots.Service, policyService *policy.Service, preauthService *preauth.Service, bindService *bind.Service, rc *boot.RuntimeConfig) *inbound.ChannelInboundProcessor {
return inbound.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, botService, policyService, preauthService, bindService, rc.JwtSecret, 5*time.Minute)
}
func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, channelRouter *router.ChannelInboundProcessor) *channel.Manager {
func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelService *channel.Service, channelRouter *inbound.ChannelInboundProcessor) *channel.Manager {
mgr := channel.NewManager(log, registry, channelService, channelRouter)
if mw := channelRouter.IdentityMiddleware(); mw != nil {
mgr.Use(mw)
+1 -5
View File
@@ -19,11 +19,6 @@ CREATE TABLE IF NOT EXISTS users (
avatar_url TEXT,
data_root TEXT,
last_login_at TIMESTAMPTZ,
chat_model_id TEXT,
memory_model_id TEXT,
embedding_model_id TEXT,
max_context_load_time INTEGER NOT NULL DEFAULT 1440,
language TEXT NOT NULL DEFAULT 'auto',
is_active BOOLEAN NOT NULL DEFAULT true,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
@@ -39,6 +34,7 @@ CREATE TABLE IF NOT EXISTS channel_identities (
channel_type TEXT NOT NULL,
channel_subject_id TEXT NOT NULL,
display_name TEXT,
avatar_url TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
@@ -1,2 +0,0 @@
-- 0002_channel_identity_avatar (down)
ALTER TABLE channel_identities DROP COLUMN IF EXISTS avatar_url;
@@ -1,3 +0,0 @@
-- 0002_channel_identity_avatar
-- Add avatar_url column to channel_identities for sender profile display.
ALTER TABLE channel_identities ADD COLUMN IF NOT EXISTS avatar_url TEXT;
-5
View File
@@ -43,8 +43,3 @@ SET user_id = $2, updated_at = now()
WHERE id = $1
RETURNING id, user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata, created_at, updated_at;
-- name: ClearChannelIdentityLinkedUser :one
UPDATE channel_identities
SET user_id = NULL, updated_at = now()
WHERE id = $1
RETURNING id, user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata, created_at, updated_at;
-3
View File
@@ -29,9 +29,6 @@ ON CONFLICT (container_id) DO UPDATE SET
last_stopped_at = EXCLUDED.last_stopped_at,
updated_at = now();
-- name: GetContainerByContainerID :one
SELECT * FROM containers WHERE container_id = sqlc.arg(container_id);
-- name: GetContainerByBotID :one
SELECT * FROM containers WHERE bot_id = sqlc.arg(bot_id) ORDER BY updated_at DESC LIMIT 1;
+21 -15
View File
@@ -59,6 +59,7 @@ SELECT
b.display_name AS title,
b.owner_user_id AS created_by_user_id,
b.metadata AS metadata,
chat_models.model_id AS model_id,
b.created_at,
b.updated_at,
'participant'::text AS access_mode,
@@ -69,6 +70,7 @@ SELECT
NULL::timestamptz AS last_observed_at
FROM bots b
LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = sqlc.arg(user_id)
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
WHERE b.id = sqlc.arg(bot_id)
AND (b.owner_user_id = sqlc.arg(user_id) OR bm.user_id IS NOT NULL)
ORDER BY b.updated_at DESC;
@@ -102,25 +104,29 @@ SELECT
FROM bots b
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
WHERE b.id = $1
AND false
ORDER BY b.created_at DESC;
-- name: UpdateChatTitle :one
UPDATE bots
SET display_name = sqlc.arg(title),
updated_at = now()
WHERE id = sqlc.arg(id)
RETURNING
id,
id AS bot_id,
CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind,
WITH updated AS (
UPDATE bots
SET display_name = sqlc.arg(title),
updated_at = now()
WHERE bots.id = sqlc.arg(bot_id)
RETURNING *
)
SELECT
updated.id AS id,
updated.id AS bot_id,
CASE WHEN updated.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
NULL::uuid AS parent_chat_id,
display_name AS title,
owner_user_id AS created_by_user_id,
metadata,
NULL::text AS model_id,
created_at,
updated_at;
updated.display_name AS title,
updated.owner_user_id AS created_by_user_id,
updated.metadata,
chat_models.model_id AS model_id,
updated.created_at,
updated.updated_at
FROM updated
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id;
-- name: TouchChat :exec
UPDATE bots
-3
View File
@@ -6,6 +6,3 @@ VALUES (
sqlc.arg(event_type),
sqlc.arg(payload)
);
-- name: ListLifecycleEventsByContainerID :many
SELECT * FROM lifecycle_events WHERE container_id = sqlc.arg(container_id) ORDER BY created_at ASC;
+2 -1
View File
@@ -56,7 +56,8 @@ SELECT
FROM bot_history_messages m
LEFT JOIN channel_identities ci ON ci.id = m.sender_channel_identity_id
WHERE m.bot_id = sqlc.arg(bot_id)
ORDER BY m.created_at ASC;
ORDER BY m.created_at ASC
LIMIT 10000;
-- name: ListMessagesSince :many
SELECT
-20
View File
@@ -136,29 +136,9 @@ VALUES (
)
RETURNING *;
-- name: GetModelVariantByID :one
SELECT * FROM model_variants WHERE id = sqlc.arg(id);
-- name: ListModelVariantsByModelUUID :many
SELECT * FROM model_variants
WHERE model_uuid = sqlc.arg(model_uuid)
ORDER BY weight DESC, created_at DESC;
-- name: ListModelVariantsByVariantID :many
SELECT * FROM model_variants
WHERE variant_id = sqlc.arg(variant_id)
ORDER BY created_at DESC;
-- name: UpdateModelVariant :one
UPDATE model_variants
SET
variant_id = sqlc.arg(variant_id),
weight = sqlc.arg(weight),
metadata = sqlc.arg(metadata),
updated_at = now()
WHERE id = sqlc.arg(id)
RETURNING *;
-- name: DeleteModelVariant :exec
DELETE FROM model_variants WHERE id = sqlc.arg(id);
+2
View File
@@ -7,6 +7,8 @@ RETURNING id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at;
SELECT id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at
FROM bot_preauth_keys
WHERE token = $1
AND used_at IS NULL
AND (expires_at IS NULL OR expires_at > now())
LIMIT 1;
-- name: MarkBotPreauthKeyUsed :one
+3 -16
View File
@@ -1,19 +1,3 @@
-- name: GetSettingsByUserID :one
SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
FROM users
WHERE id = $1;
-- name: UpsertUserSettings :one
UPDATE users
SET chat_model_id = $2,
memory_model_id = $3,
embedding_model_id = $4,
max_context_load_time = $5,
language = $6,
updated_at = now()
WHERE id = $1
RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language;
-- name: GetSettingsByBotID :one
SELECT
bots.id AS bot_id,
@@ -60,5 +44,8 @@ UPDATE bots
SET max_context_load_time = 1440,
language = 'auto',
allow_guest = false,
chat_model_id = NULL,
memory_model_id = NULL,
embedding_model_id = NULL,
updated_at = now()
WHERE id = $1;
-3
View File
@@ -8,6 +8,3 @@ VALUES (
sqlc.arg(digest)
)
ON CONFLICT (id) DO NOTHING;
-- name: ListSnapshotsByContainerID :many
SELECT * FROM snapshots WHERE container_id = sqlc.arg(container_id) ORDER BY created_at ASC;
-10
View File
@@ -8,13 +8,6 @@ SELECT *
FROM users
WHERE id = $1;
-- name: UpdateUserStatus :one
UPDATE users
SET is_active = $2,
updated_at = now()
WHERE id = $1
RETURNING *;
-- name: CreateAccount :one
UPDATE users
SET username = sqlc.arg(username),
@@ -54,9 +47,6 @@ ON CONFLICT (username) DO UPDATE SET
updated_at = now()
RETURNING *;
-- name: GetAccountByUsername :one
SELECT * FROM users WHERE username = sqlc.arg(username);
-- name: GetAccountByIdentity :one
SELECT * FROM users WHERE username = sqlc.arg(identity) OR email = sqlc.arg(identity);
+4 -2
View File
@@ -189,6 +189,8 @@ func (s *Service) Create(ctx context.Context, userID string, req CreateAccountRe
}
// CreateHuman keeps compatibility with older call sites.
//
// Deprecated: use Create directly.
func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) {
userID = strings.TrimSpace(userID)
if userID == "" {
@@ -223,7 +225,7 @@ func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAcco
if err != nil {
return Account{}, err
}
role := fmt.Sprint(existing.Role)
role := existing.Role
if req.Role != nil {
role, err = normalizeRole(*req.Role)
if err != nil {
@@ -412,7 +414,7 @@ func toAccount(row sqlc.User) Account {
ID: row.ID.String(),
Username: username,
Email: email,
Role: fmt.Sprint(row.Role),
Role: row.Role,
DisplayName: displayName,
AvatarURL: avatarURL,
IsActive: row.IsActive,
+6 -2
View File
@@ -419,7 +419,9 @@ func (s *Service) enqueueDeleteLifecycle(botID string) {
slog.String("bot_id", botID),
slog.Any("error", err),
)
_ = s.updateStatus(ctx, botID, BotStatusReady)
if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil {
s.logger.Error("revert bot status failed", slog.String("bot_id", botID), slog.Any("error", err))
}
return
}
if err := s.queries.DeleteBotByID(ctx, botUUID); err != nil {
@@ -427,7 +429,9 @@ func (s *Service) enqueueDeleteLifecycle(botID string) {
slog.String("bot_id", botID),
slog.Any("error", err),
)
_ = s.updateStatus(ctx, botID, BotStatusReady)
if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil {
s.logger.Error("revert bot status failed", slog.String("bot_id", botID), slog.Any("error", err))
}
return
}
}()
+4 -1
View File
@@ -3,6 +3,7 @@ package feishu
import (
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
@@ -26,7 +27,9 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1, botOpenID string) ch
var contentMap map[string]any
if message.Content != nil {
_ = json.Unmarshal([]byte(*message.Content), &contentMap)
if err := json.Unmarshal([]byte(*message.Content), &contentMap); err != nil {
slog.Warn("feishu inbound: unmarshal content failed", slog.Any("error", err))
}
}
isMentioned := isFeishuBotMentioned(contentMap, message.Mentions, botOpenID)
+6 -2
View File
@@ -134,8 +134,12 @@ func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamE
finalText := strings.TrimSpace(s.buf.String())
s.mu.Unlock()
if finalText != "" {
_ = s.ensureStreamMessage(ctx, finalText)
_ = s.editStreamMessage(ctx, finalText)
if err := s.ensureStreamMessage(ctx, finalText); err != nil {
slog.Warn("telegram: ensure stream message failed", slog.Any("error", err))
}
if err := s.editStreamMessage(ctx, finalText); err != nil {
slog.Warn("telegram: edit stream message failed", slog.Any("error", err))
}
}
return nil
}
+3 -1
View File
@@ -264,7 +264,9 @@ func (s *Service) LinkChannelIdentityToUser(ctx context.Context, channelIdentity
func toChannelIdentity(row sqlc.ChannelIdentity) ChannelIdentity {
var metadata map[string]any
if len(row.Metadata) > 0 {
_ = json.Unmarshal(row.Metadata, &metadata)
if err := json.Unmarshal(row.Metadata, &metadata); err != nil {
slog.Warn("unmarshal channel identity metadata failed", slog.String("id", row.ID.String()), slog.Any("error", err))
}
}
if metadata == nil {
metadata = map[string]any{}
@@ -1,4 +1,4 @@
package router
package inbound
import (
"context"
@@ -201,7 +201,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
var desc channel.Descriptor
if p.registry != nil {
desc, _ = p.registry.GetDescriptor(msg.Channel)
desc, _ = p.registry.GetDescriptor(msg.Channel) //nolint:errcheck // descriptor lookup is best-effort
}
statusInfo := channel.ProcessingStatusInfo{
BotID: identity.BotID,
@@ -269,7 +269,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel
return err
}
chunkCh, streamErrCh := p.runner.StreamChat(ctx, flow.ChatRequest{
chunkCh, streamErrCh := p.runner.StreamChat(ctx, conversation.ChatRequest{
BotID: identity.BotID,
ChatID: activeChatID,
Token: token,
@@ -1,4 +1,4 @@
package router
package inbound
import (
"context"
@@ -1,4 +1,4 @@
package router
package inbound
import (
"context"
@@ -568,23 +568,3 @@ func (r *IdentityResolver) tryLinkConfiglessChannelIdentityToUser(ctx context.Co
}
return candidateUserID
}
func buildSessionMetadata(msg channel.InboundMessage) map[string]any {
metadata := map[string]any{}
if strings.TrimSpace(msg.Source) != "" {
metadata["source"] = strings.TrimSpace(msg.Source)
}
if strings.TrimSpace(msg.Message.ID) != "" {
metadata["message_id"] = strings.TrimSpace(msg.Message.ID)
}
if strings.TrimSpace(msg.Conversation.Type) != "" {
metadata["conversation_type"] = strings.TrimSpace(msg.Conversation.Type)
}
if strings.TrimSpace(msg.Conversation.Name) != "" {
metadata["conversation_name"] = strings.TrimSpace(msg.Conversation.Name)
}
if !msg.ReceivedAt.IsZero() {
metadata["received_at"] = msg.ReceivedAt.UTC().Format(time.RFC3339Nano)
}
return metadata
}
@@ -1,4 +1,4 @@
package router
package inbound
import (
"context"
+8 -5
View File
@@ -16,7 +16,7 @@ import (
// ConversationService contains the minimal conversation behavior required by route resolution.
type ConversationService interface {
Create(ctx context.Context, botID, channelIdentityID string, req conversation.CreateRequest) (conversation.Chat, error)
Create(ctx context.Context, botID, channelIdentityID string, req conversation.CreateRequest) (conversation.Conversation, error)
IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error)
AddParticipant(ctx context.Context, conversationID, channelIdentityID, role string) (conversation.Participant, error)
}
@@ -250,11 +250,12 @@ func (s *DBService) resolveConversationCreatorChannelIdentityID(ctx context.Cont
}
return fallback
}
ownerChannelIdentityID := row.OwnerUserID.String()
if strings.TrimSpace(ownerChannelIdentityID) == "" {
// NOTE: OwnerUserID is the bot owner's user ID. Used as fallback creator for group conversations.
ownerUserID := row.OwnerUserID.String()
if strings.TrimSpace(ownerUserID) == "" {
return fallback
}
return ownerChannelIdentityID
return ownerUserID
}
func toRouteFromCreate(row sqlc.CreateChatRouteRow) Route {
@@ -357,6 +358,8 @@ func parseJSONMap(data []byte) map[string]any {
return nil
}
var m map[string]any
_ = json.Unmarshal(data, &m)
if err := json.Unmarshal(data, &m); err != nil {
slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err))
}
return m
}
+10 -6
View File
@@ -1,13 +1,17 @@
package flow
import "strings"
import (
"strings"
"github.com/memohai/memoh/internal/conversation"
)
// ExtractAssistantOutputs collects assistant-role outputs from a slice of ModelMessages.
func ExtractAssistantOutputs(messages []ModelMessage) []AssistantOutput {
func ExtractAssistantOutputs(messages []conversation.ModelMessage) []conversation.AssistantOutput {
if len(messages) == 0 {
return nil
}
outputs := make([]AssistantOutput, 0, len(messages))
outputs := make([]conversation.AssistantOutput, 0, len(messages))
for _, msg := range messages {
if msg.Role != "assistant" {
continue
@@ -17,16 +21,16 @@ func ExtractAssistantOutputs(messages []ModelMessage) []AssistantOutput {
if content == "" && len(parts) == 0 {
continue
}
outputs = append(outputs, AssistantOutput{Content: content, Parts: parts})
outputs = append(outputs, conversation.AssistantOutput{Content: content, Parts: parts})
}
return outputs
}
func filterContentParts(parts []ContentPart) []ContentPart {
func filterContentParts(parts []conversation.ContentPart) []conversation.ContentPart {
if len(parts) == 0 {
return nil
}
filtered := make([]ContentPart, 0, len(parts))
filtered := make([]conversation.ContentPart, 0, len(parts))
for _, p := range parts {
if p.HasValue() {
filtered = append(filtered, p)
+76 -135
View File
@@ -18,7 +18,6 @@ import (
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/mcp"
"github.com/memohai/memoh/internal/memory"
messagepkg "github.com/memohai/memoh/internal/message"
"github.com/memohai/memoh/internal/models"
@@ -60,7 +59,6 @@ type Resolver struct {
conversationSvc ConversationSettingsReader
messageService messagepkg.Service
settingsService *settings.Service
mcpService *mcp.ConnectionService
skillLoader SkillLoader
gatewayBaseURL string
timeout time.Duration
@@ -78,7 +76,6 @@ func NewResolver(
conversationSvc ConversationSettingsReader,
messageService messagepkg.Service,
settingsService *settings.Service,
mcpService *mcp.ConnectionService,
gatewayBaseURL string,
timeout time.Duration,
) *Resolver {
@@ -96,7 +93,6 @@ func NewResolver(
conversationSvc: conversationSvc,
messageService: messageService,
settingsService: settingsService,
mcpService: mcpService,
gatewayBaseURL: gatewayBaseURL,
timeout: timeout,
logger: log.With(slog.String("service", "conversation_resolver")),
@@ -126,7 +122,6 @@ type gatewayIdentity struct {
ChannelIdentityID string `json:"channelIdentityId"`
DisplayName string `json:"displayName"`
CurrentPlatform string `json:"currentPlatform,omitempty"`
ReplyTarget string `json:"replyTarget,omitempty"`
SessionToken string `json:"sessionToken,omitempty"`
}
@@ -138,22 +133,22 @@ type gatewaySkill struct {
}
type gatewayRequest struct {
Model gatewayModelConfig `json:"model"`
ActiveContextTime int `json:"activeContextTime"`
Channels []string `json:"channels"`
CurrentChannel string `json:"currentChannel"`
AllowedActions []string `json:"allowedActions,omitempty"`
Messages []ModelMessage `json:"messages"`
Skills []string `json:"skills"`
UsableSkills []gatewaySkill `json:"usableSkills"`
Query string `json:"query"`
Identity gatewayIdentity `json:"identity"`
Attachments []any `json:"attachments"`
Model gatewayModelConfig `json:"model"`
ActiveContextTime int `json:"activeContextTime"`
Channels []string `json:"channels"`
CurrentChannel string `json:"currentChannel"`
AllowedActions []string `json:"allowedActions,omitempty"`
Messages []conversation.ModelMessage `json:"messages"`
Skills []string `json:"skills"`
UsableSkills []gatewaySkill `json:"usableSkills"`
Query string `json:"query,omitempty"`
Identity gatewayIdentity `json:"identity"`
Attachments []any `json:"attachments"`
}
type gatewayResponse struct {
Messages []ModelMessage `json:"messages"`
Skills []string `json:"skills"`
Messages []conversation.ModelMessage `json:"messages"`
Skills []string `json:"skills"`
}
// gatewaySchedule matches the agent gateway ScheduleModel for /chat/trigger-schedule.
@@ -168,17 +163,8 @@ type gatewaySchedule struct {
// triggerScheduleRequest is the payload for POST /chat/trigger-schedule.
type triggerScheduleRequest struct {
Model gatewayModelConfig `json:"model"`
ActiveContextTime int `json:"activeContextTime"`
Channels []string `json:"channels"`
CurrentChannel string `json:"currentChannel"`
AllowedActions []string `json:"allowedActions,omitempty"`
Messages []ModelMessage `json:"messages"`
Skills []string `json:"skills"`
UsableSkills []gatewaySkill `json:"usableSkills"`
Identity gatewayIdentity `json:"identity"`
Attachments []any `json:"attachments"`
Schedule gatewaySchedule `json:"schedule"`
gatewayRequest
Schedule gatewaySchedule `json:"schedule"`
}
// --- resolved context (shared by Chat / StreamChat / TriggerSchedule) ---
@@ -189,7 +175,7 @@ type resolvedContext struct {
provider sqlc.LlmProvider
}
func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContext, error) {
func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) {
if strings.TrimSpace(req.Query) == "" {
return resolvedContext{}, fmt.Errorf("query is required")
}
@@ -208,7 +194,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
}
// Check chat-level model override.
var chatSettings Settings
var chatSettings conversation.Settings
if r.conversationSvc != nil {
chatSettings, err = r.conversationSvc.GetSettings(ctx, req.ChatID)
if err != nil {
@@ -216,11 +202,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
}
}
userSettings, err := r.loadUserSettings(ctx, req.UserID)
if err != nil {
return resolvedContext{}, err
}
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, userSettings, chatSettings)
chatModel, provider, err := r.selectChatModel(ctx, req, botSettings, chatSettings)
if err != nil {
return resolvedContext{}, err
}
@@ -230,7 +212,7 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
}
maxCtx := coalescePositiveInt(req.MaxContextLoadTime, botSettings.MaxContextLoadTime, defaultMaxContextMinutes)
var messages []ModelMessage
var messages []conversation.ModelMessage
if !skipHistory && r.conversationSvc != nil {
messages, err = r.loadMessages(ctx, req.ChatID, maxCtx)
if err != nil {
@@ -285,10 +267,9 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
Identity: gatewayIdentity{
BotID: req.BotID,
ContainerID: containerID,
ChannelIdentityID: firstNonEmpty(req.SourceChannelIdentityID, req.UserID),
ChannelIdentityID: strings.TrimSpace(req.SourceChannelIdentityID),
DisplayName: r.resolveDisplayName(ctx, req),
CurrentPlatform: req.CurrentChannel,
ReplyTarget: "",
SessionToken: req.ChatToken,
},
Attachments: []any{},
@@ -300,19 +281,19 @@ func (r *Resolver) resolve(ctx context.Context, req ChatRequest) (resolvedContex
// --- Chat ---
// Chat sends a synchronous chat request to the agent gateway and stores the result.
func (r *Resolver) Chat(ctx context.Context, req ChatRequest) (ChatResponse, error) {
func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) {
rc, err := r.resolve(ctx, req)
if err != nil {
return ChatResponse{}, err
return conversation.ChatResponse{}, err
}
resp, err := r.postChat(ctx, rc.payload, req.Token)
if err != nil {
return ChatResponse{}, err
return conversation.ChatResponse{}, err
}
if err := r.storeRound(ctx, req, resp.Messages); err != nil {
return ChatResponse{}, err
return conversation.ChatResponse{}, err
}
return ChatResponse{
return conversation.ChatResponse{
Messages: resp.Messages,
Skills: resp.Skills,
Model: rc.model.ModelID,
@@ -331,11 +312,8 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
return fmt.Errorf("schedule command is required")
}
chatID := payload.ChatID
if strings.TrimSpace(chatID) == "" {
chatID = "schedule-" + payload.ID
}
req := ChatRequest{
chatID := "schedule-" + payload.ID
req := conversation.ChatRequest{
BotID: botID,
ChatID: chatID,
Query: payload.Command,
@@ -347,22 +325,12 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
return err
}
schedulePayload := rc.payload
schedulePayload.Identity.ChannelIdentityID = strings.TrimSpace(payload.OwnerUserID)
schedulePayload.Identity.DisplayName = "Scheduler"
triggerReq := triggerScheduleRequest{
Model: rc.payload.Model,
ActiveContextTime: rc.payload.ActiveContextTime,
Channels: rc.payload.Channels,
CurrentChannel: rc.payload.CurrentChannel,
AllowedActions: rc.payload.AllowedActions,
Messages: rc.payload.Messages,
Skills: rc.payload.Skills,
UsableSkills: rc.payload.UsableSkills,
Identity: gatewayIdentity{
BotID: rc.payload.Identity.BotID,
ContainerID: rc.payload.Identity.ContainerID,
ChannelIdentityID: strings.TrimSpace(payload.OwnerUserID),
DisplayName: "Scheduler",
},
Attachments: rc.payload.Attachments,
gatewayRequest: schedulePayload,
Schedule: gatewaySchedule{
ID: payload.ID,
Name: payload.Name,
@@ -383,8 +351,8 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc
// --- StreamChat ---
// StreamChat sends a streaming chat request to the agent gateway.
func (r *Resolver) StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error) {
chunkCh := make(chan StreamChunk)
func (r *Resolver) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) {
chunkCh := make(chan conversation.StreamChunk)
errCh := make(chan error, 1)
r.logger.Info("gateway stream start",
slog.String("bot_id", req.BotID),
@@ -513,7 +481,7 @@ func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerSched
return parsed, nil
}
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req ChatRequest, chunkCh chan<- StreamChunk) error {
func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req conversation.ChatRequest, chunkCh chan<- conversation.StreamChunk) error {
body, err := json.Marshal(payload)
if err != nil {
return err
@@ -564,7 +532,7 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req C
if data == "" || data == "[DONE]" {
continue
}
chunkCh <- StreamChunk([]byte(data))
chunkCh <- conversation.StreamChunk([]byte(data))
if stored {
continue
@@ -579,7 +547,7 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req C
}
// tryStoreStream attempts to extract final messages from a stream event and persist them.
func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventType, data string) (bool, error) {
func (r *Resolver) tryStoreStream(ctx context.Context, req conversation.ChatRequest, eventType, data string) (bool, error) {
// event: done + data: {messages: [...]}
if eventType == "done" {
var resp gatewayResponse
@@ -590,10 +558,10 @@ func (r *Resolver) tryStoreStream(ctx context.Context, req ChatRequest, eventTyp
// data: {"type":"text_delta"|"agent_end"|"done", ...}
var envelope struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
Messages []ModelMessage `json:"messages"`
Skills []string `json:"skills"`
Type string `json:"type"`
Data json.RawMessage `json:"data"`
Messages []conversation.ModelMessage `json:"messages"`
Skills []string `json:"skills"`
}
if err := json.Unmarshal([]byte(data), &envelope); err == nil {
if (envelope.Type == "agent_end" || envelope.Type == "done") && len(envelope.Messages) > 0 {
@@ -630,12 +598,13 @@ func (r *Resolver) resolveContainerID(ctx context.Context, botID, explicit strin
}
}
}
r.logger.Warn("no container found for bot, using fallback", slog.String("bot_id", botID))
return "mcp-" + botID
}
// --- message loading ---
func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]ModelMessage, error) {
func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMinutes int) ([]conversation.ModelMessage, error) {
if r.messageService == nil {
return nil, nil
}
@@ -644,12 +613,13 @@ func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMi
if err != nil {
return nil, err
}
var result []ModelMessage
var result []conversation.ModelMessage
for _, m := range msgs {
var mm ModelMessage
var mm conversation.ModelMessage
if err := json.Unmarshal(m.Content, &mm); err != nil {
// Fallback: treat content as text string.
mm = ModelMessage{Role: m.Role, Content: m.Content}
r.logger.Warn("loadMessages: content unmarshal failed, treating as raw text",
slog.String("chat_id", chatID), slog.Any("error", err))
mm = conversation.ModelMessage{Role: m.Role, Content: m.Content}
} else {
mm.Role = m.Role
}
@@ -663,7 +633,7 @@ type memoryContextItem struct {
Item memory.MemoryItem
}
func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest) *ModelMessage {
func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversation.ChatRequest) *conversation.ModelMessage {
if r.memoryService == nil {
return nil
}
@@ -680,7 +650,7 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest
Filters: map[string]any{
"namespace": sharedMemoryNamespace,
"scopeId": req.BotID,
"botId": req.BotID,
"bot_id": req.BotID,
},
})
if err != nil {
@@ -732,16 +702,16 @@ func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req ChatRequest
if payload == "" {
return nil
}
msg := ModelMessage{
msg := conversation.ModelMessage{
Role: "system",
Content: NewTextContent(payload),
Content: conversation.NewTextContent(payload),
}
return &msg
}
// --- store helpers ---
func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) error {
func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.ChatRequest) error {
if r.messageService == nil {
return nil
}
@@ -753,9 +723,9 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) erro
return nil
}
message := ModelMessage{
message := conversation.ModelMessage{
Role: "user",
Content: NewTextContent(text),
Content: conversation.NewTextContent(text),
}
content, err := json.Marshal(message)
if err != nil {
@@ -776,10 +746,10 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req ChatRequest) erro
return err
}
func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []ModelMessage) error {
func (r *Resolver) storeRound(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage) error {
// Add user query as the first message if not already present in the round.
// This ensures the user's prompt is persisted alongside the assistant's response.
fullRound := make([]ModelMessage, 0, len(messages)+1)
fullRound := make([]conversation.ModelMessage, 0, len(messages)+1)
hasUserQuery := false
for _, m := range messages {
if m.Role == "user" && m.TextContent() == req.Query {
@@ -788,9 +758,9 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M
}
}
if !req.UserMessagePersisted && !hasUserQuery && strings.TrimSpace(req.Query) != "" {
fullRound = append(fullRound, ModelMessage{
fullRound = append(fullRound, conversation.ModelMessage{
Role: "user",
Content: NewTextContent(req.Query),
Content: conversation.NewTextContent(req.Query),
})
}
for _, m := range messages {
@@ -809,7 +779,7 @@ func (r *Resolver) storeRound(ctx context.Context, req ChatRequest, messages []M
return nil
}
func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages []ModelMessage) {
func (r *Resolver) storeMessages(ctx context.Context, req conversation.ChatRequest, messages []conversation.ModelMessage) {
if r.messageService == nil {
return
}
@@ -821,6 +791,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
for _, msg := range messages {
content, err := json.Marshal(msg)
if err != nil {
r.logger.Warn("storeMessages: marshal failed", slog.Any("error", err))
continue
}
messageSenderChannelIdentityID := ""
@@ -852,7 +823,7 @@ func (r *Resolver) storeMessages(ctx context.Context, req ChatRequest, messages
}
}
func buildRouteMetadata(req ChatRequest) map[string]any {
func buildRouteMetadata(req conversation.ChatRequest) map[string]any {
if strings.TrimSpace(req.RouteID) == "" && strings.TrimSpace(req.CurrentChannel) == "" {
return nil
}
@@ -866,25 +837,17 @@ func buildRouteMetadata(req ChatRequest) map[string]any {
return meta
}
func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req ChatRequest) (string, string) {
func (r *Resolver) resolvePersistSenderIDs(ctx context.Context, req conversation.ChatRequest) (string, string) {
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
userID := strings.TrimSpace(req.UserID)
channelIdentityValid := r.isExistingChannelIdentityID(ctx, channelIdentityID)
userAsUserValid := r.isExistingUserID(ctx, userID)
userAsChannelIdentityValid := r.isExistingChannelIdentityID(ctx, userID)
senderChannelIdentityID := ""
switch {
case channelIdentityValid:
if r.isExistingChannelIdentityID(ctx, channelIdentityID) {
senderChannelIdentityID = channelIdentityID
case userAsChannelIdentityValid && !userAsUserValid:
// Some flows may carry channel_identity_id in req.UserID.
senderChannelIdentityID = userID
}
senderUserID := ""
if userAsUserValid {
if r.isExistingUserID(ctx, userID) {
senderUserID = userID
}
if senderUserID == "" && senderChannelIdentityID != "" {
@@ -936,14 +899,14 @@ func (r *Resolver) linkedUserIDFromChannelIdentity(ctx context.Context, channelI
// resolveDisplayName returns the best available display name for the request identity:
// req.DisplayName if set, else channel identity's display_name, else linked user's display_name, else "User".
func (r *Resolver) resolveDisplayName(ctx context.Context, req ChatRequest) string {
func (r *Resolver) resolveDisplayName(ctx context.Context, req conversation.ChatRequest) string {
if name := strings.TrimSpace(req.DisplayName); name != "" {
return name
}
if r.queries == nil {
return "User"
}
channelIdentityID := firstNonEmpty(req.SourceChannelIdentityID, req.UserID)
channelIdentityID := strings.TrimSpace(req.SourceChannelIdentityID)
if channelIdentityID == "" {
return "User"
}
@@ -975,7 +938,7 @@ func (r *Resolver) resolveDisplayName(ctx context.Context, req ChatRequest) stri
return "User"
}
func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []ModelMessage) {
func (r *Resolver) storeMemory(ctx context.Context, botID string, messages []conversation.ModelMessage) {
if r.memoryService == nil {
return
}
@@ -1004,7 +967,7 @@ func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Me
filters := map[string]any{
"namespace": namespace,
"scopeId": scopeID,
"botId": botID,
"bot_id": botID,
}
if _, err := r.memoryService.Add(ctx, memory.AddRequest{
Messages: msgs,
@@ -1021,21 +984,19 @@ func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Me
// --- model selection ---
func (r *Resolver) selectChatModel(ctx context.Context, req ChatRequest, botSettings settings.Settings, us resolvedUserSettings, cs Settings) (models.GetResponse, sqlc.LlmProvider, error) {
func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.LlmProvider, error) {
if r.modelsService == nil {
return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured")
}
modelID := strings.TrimSpace(req.Model)
providerFilter := strings.TrimSpace(req.Provider)
// Priority: request model > chat settings > bot settings > user settings.
// Priority: request model > chat settings > bot settings.
if modelID == "" && providerFilter == "" {
if value := strings.TrimSpace(cs.ModelID); value != "" {
modelID = value
} else if value := strings.TrimSpace(botSettings.ChatModelID); value != "" {
modelID = value
} else if value := strings.TrimSpace(us.ChatModelID); value != "" {
modelID = value
}
}
@@ -1100,29 +1061,9 @@ func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([
// --- settings ---
type resolvedUserSettings struct {
ChatModelID string
}
func (r *Resolver) loadUserSettings(ctx context.Context, userID string) (resolvedUserSettings, error) {
if r.settingsService == nil || strings.TrimSpace(userID) == "" {
return resolvedUserSettings{}, nil
}
s, err := r.settingsService.Get(ctx, userID)
if err != nil {
return resolvedUserSettings{}, err
}
return resolvedUserSettings{
ChatModelID: strings.TrimSpace(s.ChatModelID),
}, nil
}
func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) {
if r.settingsService == nil {
return settings.Settings{
MaxContextLoadTime: settings.DefaultMaxContextLoadTime,
Language: settings.DefaultLanguage,
}, nil
return settings.Settings{}, fmt.Errorf("settings service not configured")
}
return r.settingsService.GetBot(ctx, botID)
}
@@ -1140,8 +1081,8 @@ func normalizeClientType(clientType string) (string, error) {
}
}
func sanitizeMessages(messages []ModelMessage) []ModelMessage {
cleaned := make([]ModelMessage, 0, len(messages))
func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage {
cleaned := make([]conversation.ModelMessage, 0, len(messages))
for _, msg := range messages {
if strings.TrimSpace(msg.Role) == "" {
continue
@@ -1196,9 +1137,9 @@ func nonNilStrings(s []string) []string {
return s
}
func nonNilModelMessages(m []ModelMessage) []ModelMessage {
func nonNilModelMessages(m []conversation.ModelMessage) []conversation.ModelMessage {
if m == nil {
return []ModelMessage{}
return []conversation.ModelMessage{}
}
return m
}
@@ -6,6 +6,7 @@ import (
"strings"
"testing"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/memory"
)
@@ -14,7 +15,7 @@ func TestLoadMemoryContextMessage_NoMemoryService(t *testing.T) {
memoryService: nil,
logger: slog.Default(),
}
msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{
msg := resolver.loadMemoryContextMessage(context.Background(), conversation.ChatRequest{
Query: "hello",
BotID: "bot-1",
ChatID: "chat-1",
@@ -29,7 +30,7 @@ func TestLoadMemoryContextMessage_SearchFailureFallback(t *testing.T) {
memoryService: &memory.Service{},
logger: slog.Default(),
}
msg := resolver.loadMemoryContextMessage(context.Background(), ChatRequest{
msg := resolver.loadMemoryContextMessage(context.Background(), conversation.ChatRequest{
Query: "hello",
BotID: "bot-1",
ChatID: "chat-1",
+36 -28
View File
@@ -9,6 +9,8 @@ import (
"net/http/httptest"
"testing"
"time"
"github.com/memohai/memoh/internal/conversation"
)
func TestPostTriggerSchedule_Endpoint(t *testing.T) {
@@ -21,7 +23,7 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) {
capturedAuth = r.Header.Get("Authorization")
capturedBody, _ = io.ReadAll(r.Body)
resp := gatewayResponse{
Messages: []ModelMessage{{Role: "assistant", Content: NewTextContent("ok")}},
Messages: []conversation.ModelMessage{{Role: "assistant", Content: conversation.NewTextContent("ok")}},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
@@ -36,23 +38,25 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) {
maxCalls := 5
req := triggerScheduleRequest{
Model: gatewayModelConfig{
ModelID: "gpt-4",
ClientType: "openai",
APIKey: "sk-test",
BaseURL: "https://api.openai.com",
gatewayRequest: gatewayRequest{
Model: gatewayModelConfig{
ModelID: "gpt-4",
ClientType: "openai",
APIKey: "sk-test",
BaseURL: "https://api.openai.com",
},
ActiveContextTime: 1440,
Channels: []string{},
Messages: []conversation.ModelMessage{},
Skills: []string{},
Identity: gatewayIdentity{
BotID: "bot-123",
ContainerID: "mcp-bot-123",
ChannelIdentityID: "owner-user-1",
DisplayName: "Scheduler",
},
Attachments: []any{},
},
ActiveContextTime: 1440,
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Identity: gatewayIdentity{
BotID: "bot-123",
ContainerID: "mcp-bot-123",
ChannelIdentityID: "owner-user-1",
DisplayName: "Scheduler",
},
Attachments: []any{},
Schedule: gatewaySchedule{
ID: "sched-1",
Name: "daily report",
@@ -102,7 +106,7 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
resp := gatewayResponse{Messages: []ModelMessage{}}
resp := gatewayResponse{Messages: []conversation.ModelMessage{}}
json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
@@ -114,11 +118,13 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) {
}
req := triggerScheduleRequest{
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Attachments: []any{},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
gatewayRequest: gatewayRequest{
Channels: []string{},
Messages: []conversation.ModelMessage{},
Skills: []string{},
Attachments: []any{},
},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "")
@@ -144,11 +150,13 @@ func TestPostTriggerSchedule_GatewayError(t *testing.T) {
}
req := triggerScheduleRequest{
Channels: []string{},
Messages: []ModelMessage{},
Skills: []string{},
Attachments: []any{},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
gatewayRequest: gatewayRequest{
Channels: []string{},
Messages: []conversation.ModelMessage{},
Skills: []string{},
Attachments: []any{},
},
Schedule: gatewaySchedule{ID: "s1", Command: "test"},
}
_, err := resolver.postTriggerSchedule(context.Background(), req, "Bearer tok")
+3 -2
View File
@@ -3,12 +3,13 @@ package flow
import (
"context"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/schedule"
)
// Runner defines conversation execution behavior for sync, stream, and scheduled flows.
type Runner interface {
Chat(ctx context.Context, req ChatRequest) (ChatResponse, error)
StreamChat(ctx context.Context, req ChatRequest) (<-chan StreamChunk, <-chan error)
Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error)
StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error)
TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error
}
-15
View File
@@ -1,15 +0,0 @@
package flow
import "github.com/memohai/memoh/internal/conversation"
type ModelMessage = conversation.ModelMessage
type ContentPart = conversation.ContentPart
type ToolCall = conversation.ToolCall
type ToolCallFunction = conversation.ToolCallFunction
type AssistantOutput = conversation.AssistantOutput
type ChatRequest = conversation.ChatRequest
type ChatResponse = conversation.ChatResponse
type StreamChunk = conversation.StreamChunk
type Settings = conversation.Settings
var NewTextContent = conversation.NewTextContent
+2 -2
View File
@@ -4,7 +4,7 @@ import "context"
// Reader defines conversation lookup behavior.
type Reader interface {
Get(ctx context.Context, conversationID string) (Chat, error)
Get(ctx context.Context, conversationID string) (Conversation, error)
}
// ParticipantChecker defines participant membership checks.
@@ -16,5 +16,5 @@ type ParticipantChecker interface {
type Accessor interface {
Reader
ParticipantChecker
GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error)
GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ConversationReadAccess, error)
}
File diff suppressed because it is too large Load Diff
@@ -40,24 +40,24 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
}
// Create creates a new conversation and adds the creator as owner.
func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Chat, error) {
func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, req CreateRequest) (Conversation, error) {
kind := strings.TrimSpace(req.Kind)
if kind == "" {
kind = KindDirect
}
if kind != KindDirect && kind != KindGroup && kind != KindThread {
return Chat{}, fmt.Errorf("invalid conversation kind: %s", kind)
return Conversation{}, fmt.Errorf("invalid conversation kind: %s", kind)
}
pgBotID, err := parseUUID(botID)
if err != nil {
return Chat{}, fmt.Errorf("invalid bot id: %w", err)
return Conversation{}, fmt.Errorf("invalid bot id: %w", err)
}
pgChannelIdentityID := pgtype.UUID{}
if strings.TrimSpace(channelIdentityID) != "" {
pgChannelIdentityID, err = parseUUID(channelIdentityID)
if err != nil {
return Chat{}, fmt.Errorf("invalid channel identity id: %w", err)
return Conversation{}, fmt.Errorf("invalid channel identity id: %w", err)
}
}
@@ -65,13 +65,13 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
if kind == KindThread && strings.TrimSpace(req.ParentChatID) != "" {
pgParent, err = parseUUID(req.ParentChatID)
if err != nil {
return Chat{}, fmt.Errorf("invalid parent conversation id: %w", err)
return Conversation{}, fmt.Errorf("invalid parent conversation id: %w", err)
}
}
metadata, err := json.Marshal(nonNilMap(req.Metadata))
if err != nil {
return Chat{}, fmt.Errorf("marshal conversation metadata: %w", err)
return Conversation{}, fmt.Errorf("marshal conversation metadata: %w", err)
}
row, err := s.queries.CreateChat(ctx, sqlc.CreateChatParams{
@@ -83,7 +83,7 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
Metadata: metadata,
})
if err != nil {
return Chat{}, fmt.Errorf("create conversation: %w", err)
return Conversation{}, fmt.Errorf("create conversation: %w", err)
}
// Add creator as owner when the channel identity is available.
@@ -93,7 +93,7 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
UserID: pgChannelIdentityID,
Role: RoleOwner,
}); err != nil {
return Chat{}, fmt.Errorf("add owner participant: %w", err)
return Conversation{}, fmt.Errorf("add owner participant: %w", err)
}
}
@@ -102,7 +102,7 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
if err := s.queries.CopyParticipantsToChat(ctx, sqlc.CopyParticipantsToChatParams{
ChatID: pgParent,
ChatID2: row.ID,
}); err != nil && s.logger != nil {
}); err != nil {
s.logger.Warn("copy parent participants failed", slog.Any("error", err))
}
}
@@ -111,30 +111,30 @@ func (s *Service) Create(ctx context.Context, botID, channelIdentityID string, r
}
// Get returns a conversation by ID.
func (s *Service) Get(ctx context.Context, conversationID string) (Chat, error) {
func (s *Service) Get(ctx context.Context, conversationID string) (Conversation, error) {
pgID, err := parseUUID(conversationID)
if err != nil {
return Chat{}, ErrChatNotFound
return Conversation{}, ErrChatNotFound
}
row, err := s.queries.GetChatByID(ctx, pgID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return Chat{}, ErrChatNotFound
return Conversation{}, ErrChatNotFound
}
return Chat{}, err
return Conversation{}, err
}
return toChatFromGet(row), nil
}
// GetReadAccess resolves whether a user can read a conversation.
func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ChatReadAccess, error) {
func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ConversationReadAccess, error) {
pgConversationID, err := parseUUID(conversationID)
if err != nil {
return ChatReadAccess{}, ErrPermissionDenied
return ConversationReadAccess{}, ErrPermissionDenied
}
pgChannelIdentityID, err := parseUUID(channelIdentityID)
if err != nil {
return ChatReadAccess{}, ErrPermissionDenied
return ConversationReadAccess{}, ErrPermissionDenied
}
row, err := s.queries.GetChatReadAccessByUser(ctx, sqlc.GetChatReadAccessByUserParams{
ChatID: pgConversationID,
@@ -142,11 +142,11 @@ func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIden
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return ChatReadAccess{}, ErrPermissionDenied
return ConversationReadAccess{}, ErrPermissionDenied
}
return ChatReadAccess{}, err
return ConversationReadAccess{}, err
}
return ChatReadAccess{
return ConversationReadAccess{
AccessMode: row.AccessMode,
ParticipantRole: strings.TrimSpace(row.ParticipantRole),
LastObservedAt: pgTimePtr(row.LastObservedAt),
@@ -154,7 +154,7 @@ func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIden
}
// ListByBotAndChannelIdentity returns all visible conversations for a bot and channel identity.
func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ChatListItem, error) {
func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ConversationListItem, error) {
pgBotID, err := parseUUID(botID)
if err != nil {
return nil, err
@@ -170,7 +170,7 @@ func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channe
if err != nil {
return nil, err
}
conversations := make([]ChatListItem, 0, len(rows))
conversations := make([]ConversationListItem, 0, len(rows))
for _, row := range rows {
conversations = append(conversations, toChatListItem(row))
}
@@ -178,7 +178,7 @@ func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channe
}
// ListThreads returns threads for a parent conversation.
func (s *Service) ListThreads(ctx context.Context, parentConversationID string) ([]Chat, error) {
func (s *Service) ListThreads(ctx context.Context, parentConversationID string) ([]Conversation, error) {
pgID, err := parseUUID(parentConversationID)
if err != nil {
return nil, err
@@ -187,7 +187,7 @@ func (s *Service) ListThreads(ctx context.Context, parentConversationID string)
if err != nil {
return nil, err
}
conversations := make([]Chat, 0, len(rows))
conversations := make([]Conversation, 0, len(rows))
for _, row := range rows {
conversations = append(conversations, toChatFromThread(row))
}
@@ -296,7 +296,7 @@ func (s *Service) RemoveParticipant(ctx context.Context, conversationID, channel
func (s *Service) GetSettings(ctx context.Context, conversationID string) (Settings, error) {
pgID, err := parseUUID(conversationID)
if err != nil {
return defaultSettings(conversationID), nil
return Settings{}, fmt.Errorf("invalid conversation id: %w", err)
}
row, err := s.queries.GetChatSettings(ctx, pgID)
if err != nil {
@@ -332,7 +332,7 @@ func (s *Service) UpdateSettings(ctx context.Context, conversationID string, req
return toSettingsFromUpsert(row), nil
}
func toChatFromCreate(row sqlc.CreateChatRow) Chat {
func toChatFromCreate(row sqlc.CreateChatRow) Conversation {
return toChatFields(
row.ID,
row.BotID,
@@ -346,7 +346,7 @@ func toChatFromCreate(row sqlc.CreateChatRow) Chat {
)
}
func toChatFromGet(row sqlc.GetChatByIDRow) Chat {
func toChatFromGet(row sqlc.GetChatByIDRow) Conversation {
return toChatFields(
row.ID,
row.BotID,
@@ -360,7 +360,7 @@ func toChatFromGet(row sqlc.GetChatByIDRow) Chat {
)
}
func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat {
func toChatFromThread(row sqlc.ListThreadsByParentRow) Conversation {
return toChatFields(
row.ID,
row.BotID,
@@ -374,8 +374,8 @@ func toChatFromThread(row sqlc.ListThreadsByParentRow) Chat {
)
}
func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Chat {
return Chat{
func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, title pgtype.Text, createdBy pgtype.UUID, metadata []byte, createdAt, updatedAt pgtype.Timestamptz) Conversation {
return Conversation{
ID: id.String(),
BotID: botID.String(),
Kind: kind,
@@ -388,8 +388,8 @@ func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID,
}
}
func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ChatListItem {
return ChatListItem{
func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ConversationListItem {
return ConversationListItem{
ID: row.ID.String(),
BotID: row.BotID.String(),
Kind: row.Kind,
@@ -478,6 +478,8 @@ func parseJSONMap(data []byte) map[string]any {
return nil
}
var m map[string]any
_ = json.Unmarshal(data, &m)
if err := json.Unmarshal(data, &m); err != nil {
slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err))
}
return m
}
@@ -175,7 +175,7 @@ func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) {
t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind))
}
var target *conversation.ChatListItem
var target *conversation.ConversationListItem
for i := range afterBind {
if afterBind[i].ID == chatID {
target = &afterBind[i]
+8 -9
View File
@@ -3,11 +3,12 @@ package conversation
import (
"encoding/json"
"log/slog"
"strings"
"time"
)
// Chat kind constants.
// Conversation kind constants.
const (
KindDirect = "direct"
KindGroup = "group"
@@ -21,7 +22,7 @@ const (
RoleMember = "member"
)
// Chat list access mode constants.
// Conversation list access mode constants.
const (
AccessModeParticipant = "participant"
AccessModeChannelIdentityObserved = "channel_identity_observed"
@@ -63,11 +64,6 @@ type ConversationReadAccess struct {
LastObservedAt *time.Time
}
// Backward-compatible aliases while call sites migrate.
type Chat = Conversation
type ChatListItem = ConversationListItem
type ChatReadAccess = ConversationReadAccess
// Participant represents a chat member.
type Participant struct {
ChatID string `json:"chat_id"`
@@ -155,7 +151,11 @@ func (m ModelMessage) HasContent() bool {
// NewTextContent creates a json.RawMessage from a plain string.
func NewTextContent(text string) json.RawMessage {
data, _ := json.Marshal(text)
data, err := json.Marshal(text)
if err != nil {
slog.Warn("NewTextContent: marshal failed", slog.Any("error", err))
return nil
}
return data
}
@@ -209,7 +209,6 @@ type ChatRequest struct {
Model string `json:"model,omitempty"`
Provider string `json:"provider,omitempty"`
MaxContextLoadTime int `json:"max_context_load_time,omitempty"`
Language string `json:"language,omitempty"`
Channels []string `json:"channels,omitempty"`
CurrentChannel string `json:"current_channel,omitempty"`
Messages []ModelMessage `json:"messages,omitempty"`
@@ -11,30 +11,6 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)
const clearChannelIdentityLinkedUser = `-- name: ClearChannelIdentityLinkedUser :one
UPDATE channel_identities
SET user_id = NULL, updated_at = now()
WHERE id = $1
RETURNING id, user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata, created_at, updated_at
`
func (q *Queries) ClearChannelIdentityLinkedUser(ctx context.Context, id pgtype.UUID) (ChannelIdentity, error) {
row := q.db.QueryRow(ctx, clearChannelIdentityLinkedUser, id)
var i ChannelIdentity
err := row.Scan(
&i.ID,
&i.UserID,
&i.ChannelType,
&i.ChannelSubjectID,
&i.DisplayName,
&i.AvatarUrl,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createChannelIdentity = `-- name: CreateChannelIdentity :one
INSERT INTO channel_identities (user_id, channel_type, channel_subject_id, display_name, avatar_url, metadata)
VALUES ($1, $2, $3, $4, $5, $6)
-26
View File
@@ -46,32 +46,6 @@ func (q *Queries) GetContainerByBotID(ctx context.Context, botID pgtype.UUID) (C
return i, err
}
const getContainerByContainerID = `-- name: GetContainerByContainerID :one
SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE container_id = $1
`
func (q *Queries) GetContainerByContainerID(ctx context.Context, containerID string) (Container, error) {
row := q.db.QueryRow(ctx, getContainerByContainerID, containerID)
var i Container
err := row.Scan(
&i.ID,
&i.BotID,
&i.ContainerID,
&i.ContainerName,
&i.Image,
&i.Status,
&i.Namespace,
&i.AutoStart,
&i.HostPath,
&i.ContainerPath,
&i.CreatedAt,
&i.UpdatedAt,
&i.LastStartedAt,
&i.LastStoppedAt,
)
return i, err
}
const listAutoStartContainers = `-- name: ListAutoStartContainers :many
SELECT id, bot_id, container_id, container_name, image, status, namespace, auto_start, host_path, container_path, created_at, updated_at, last_started_at, last_stopped_at FROM containers WHERE auto_start = true ORDER BY updated_at DESC
`
+25 -17
View File
@@ -428,7 +428,6 @@ SELECT
FROM bots b
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
WHERE b.id = $1
AND false
ORDER BY b.created_at DESC
`
@@ -485,6 +484,7 @@ SELECT
b.display_name AS title,
b.owner_user_id AS created_by_user_id,
b.metadata AS metadata,
chat_models.model_id AS model_id,
b.created_at,
b.updated_at,
'participant'::text AS access_mode,
@@ -495,6 +495,7 @@ SELECT
NULL::timestamptz AS last_observed_at
FROM bots b
LEFT JOIN bot_members bm ON bm.bot_id = b.id AND bm.user_id = $1
LEFT JOIN models chat_models ON chat_models.id = b.chat_model_id
WHERE b.id = $2
AND (b.owner_user_id = $1 OR bm.user_id IS NOT NULL)
ORDER BY b.updated_at DESC
@@ -513,6 +514,7 @@ type ListVisibleChatsByBotAndUserRow struct {
Title pgtype.Text `json:"title"`
CreatedByUserID pgtype.UUID `json:"created_by_user_id"`
Metadata []byte `json:"metadata"`
ModelID pgtype.Text `json:"model_id"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
AccessMode string `json:"access_mode"`
@@ -537,6 +539,7 @@ func (q *Queries) ListVisibleChatsByBotAndUser(ctx context.Context, arg ListVisi
&i.Title,
&i.CreatedByUserID,
&i.Metadata,
&i.ModelID,
&i.CreatedAt,
&i.UpdatedAt,
&i.AccessMode,
@@ -582,26 +585,31 @@ func (q *Queries) TouchChat(ctx context.Context, chatID pgtype.UUID) error {
}
const updateChatTitle = `-- name: UpdateChatTitle :one
UPDATE bots
SET display_name = $1,
updated_at = now()
WHERE id = $2
RETURNING
id,
id AS bot_id,
CASE WHEN type = 'public' THEN 'group' ELSE 'direct' END AS kind,
WITH updated AS (
UPDATE bots
SET display_name = $1,
updated_at = now()
WHERE bots.id = $2
RETURNING id, owner_user_id, type, display_name, avatar_url, is_active, status, max_context_load_time, language, allow_guest, chat_model_id, memory_model_id, embedding_model_id, metadata, created_at, updated_at
)
SELECT
updated.id AS id,
updated.id AS bot_id,
CASE WHEN updated.type = 'public' THEN 'group' ELSE 'direct' END AS kind,
NULL::uuid AS parent_chat_id,
display_name AS title,
owner_user_id AS created_by_user_id,
metadata,
NULL::text AS model_id,
created_at,
updated_at
updated.display_name AS title,
updated.owner_user_id AS created_by_user_id,
updated.metadata,
chat_models.model_id AS model_id,
updated.created_at,
updated.updated_at
FROM updated
LEFT JOIN models chat_models ON chat_models.id = updated.chat_model_id
`
type UpdateChatTitleParams struct {
Title pgtype.Text `json:"title"`
ID pgtype.UUID `json:"id"`
BotID pgtype.UUID `json:"bot_id"`
}
type UpdateChatTitleRow struct {
@@ -618,7 +626,7 @@ type UpdateChatTitleRow struct {
}
func (q *Queries) UpdateChatTitle(ctx context.Context, arg UpdateChatTitleParams) (UpdateChatTitleRow, error) {
row := q.db.QueryRow(ctx, updateChatTitle, arg.Title, arg.ID)
row := q.db.QueryRow(ctx, updateChatTitle, arg.Title, arg.BotID)
var i UpdateChatTitleRow
err := row.Scan(
&i.ID,
-30
View File
@@ -35,33 +35,3 @@ func (q *Queries) InsertLifecycleEvent(ctx context.Context, arg InsertLifecycleE
)
return err
}
const listLifecycleEventsByContainerID = `-- name: ListLifecycleEventsByContainerID :many
SELECT id, container_id, event_type, payload, created_at FROM lifecycle_events WHERE container_id = $1 ORDER BY created_at ASC
`
func (q *Queries) ListLifecycleEventsByContainerID(ctx context.Context, containerID string) ([]LifecycleEvent, error) {
rows, err := q.db.Query(ctx, listLifecycleEventsByContainerID, containerID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []LifecycleEvent
for rows.Next() {
var i LifecycleEvent
if err := rows.Scan(
&i.ID,
&i.ContainerID,
&i.EventType,
&i.Payload,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
+1
View File
@@ -140,6 +140,7 @@ FROM bot_history_messages m
LEFT JOIN channel_identities ci ON ci.id = m.sender_channel_identity_id
WHERE m.bot_id = $1
ORDER BY m.created_at ASC
LIMIT 10000
`
type ListMessagesRow struct {
+13 -18
View File
@@ -225,24 +225,19 @@ type Subagent struct {
}
type User struct {
ID pgtype.UUID `json:"id"`
Username pgtype.Text `json:"username"`
Email pgtype.Text `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Role string `json:"role"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
DataRoot pgtype.Text `json:"data_root"`
LastLoginAt pgtype.Timestamptz `json:"last_login_at"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
MaxContextLoadTime int32 `json:"max_context_load_time"`
Language string `json:"language"`
IsActive bool `json:"is_active"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
ID pgtype.UUID `json:"id"`
Username pgtype.Text `json:"username"`
Email pgtype.Text `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Role string `json:"role"`
DisplayName pgtype.Text `json:"display_name"`
AvatarUrl pgtype.Text `json:"avatar_url"`
DataRoot pgtype.Text `json:"data_root"`
LastLoginAt pgtype.Timestamptz `json:"last_login_at"`
IsActive bool `json:"is_active"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type UserChannelBinding struct {
-100
View File
@@ -208,15 +208,6 @@ func (q *Queries) DeleteModelByModelID(ctx context.Context, modelID string) erro
return err
}
const deleteModelVariant = `-- name: DeleteModelVariant :exec
DELETE FROM model_variants WHERE id = $1
`
func (q *Queries) DeleteModelVariant(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteModelVariant, id)
return err
}
const getLlmProviderByID = `-- name: GetLlmProviderByID :one
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers WHERE id = $1
`
@@ -299,25 +290,6 @@ func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model,
return i, err
}
const getModelVariantByID = `-- name: GetModelVariantByID :one
SELECT id, model_uuid, variant_id, weight, metadata, created_at, updated_at FROM model_variants WHERE id = $1
`
func (q *Queries) GetModelVariantByID(ctx context.Context, id pgtype.UUID) (ModelVariant, error) {
row := q.db.QueryRow(ctx, getModelVariantByID, id)
var i ModelVariant
err := row.Scan(
&i.ID,
&i.ModelUuid,
&i.VariantID,
&i.Weight,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const listLlmProviders = `-- name: ListLlmProviders :many
SELECT id, name, client_type, base_url, api_key, metadata, created_at, updated_at FROM llm_providers
ORDER BY created_at DESC
@@ -421,40 +393,6 @@ func (q *Queries) ListModelVariantsByModelUUID(ctx context.Context, modelUuid pg
return items, nil
}
const listModelVariantsByVariantID = `-- name: ListModelVariantsByVariantID :many
SELECT id, model_uuid, variant_id, weight, metadata, created_at, updated_at FROM model_variants
WHERE variant_id = $1
ORDER BY created_at DESC
`
func (q *Queries) ListModelVariantsByVariantID(ctx context.Context, variantID string) ([]ModelVariant, error) {
rows, err := q.db.Query(ctx, listModelVariantsByVariantID, variantID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ModelVariant
for rows.Next() {
var i ModelVariant
if err := rows.Scan(
&i.ID,
&i.ModelUuid,
&i.VariantID,
&i.Weight,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listModels = `-- name: ListModels :many
SELECT id, model_id, name, llm_provider_id, dimensions, is_multimodal, type, created_at, updated_at FROM models
ORDER BY created_at DESC
@@ -777,41 +715,3 @@ func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByMod
)
return i, err
}
const updateModelVariant = `-- name: UpdateModelVariant :one
UPDATE model_variants
SET
variant_id = $1,
weight = $2,
metadata = $3,
updated_at = now()
WHERE id = $4
RETURNING id, model_uuid, variant_id, weight, metadata, created_at, updated_at
`
type UpdateModelVariantParams struct {
VariantID string `json:"variant_id"`
Weight int32 `json:"weight"`
Metadata []byte `json:"metadata"`
ID pgtype.UUID `json:"id"`
}
func (q *Queries) UpdateModelVariant(ctx context.Context, arg UpdateModelVariantParams) (ModelVariant, error) {
row := q.db.QueryRow(ctx, updateModelVariant,
arg.VariantID,
arg.Weight,
arg.Metadata,
arg.ID,
)
var i ModelVariant
err := row.Scan(
&i.ID,
&i.ModelUuid,
&i.VariantID,
&i.Weight,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
+2
View File
@@ -48,6 +48,8 @@ const getBotPreauthKey = `-- name: GetBotPreauthKey :one
SELECT id, bot_id, token, issued_by_user_id, expires_at, used_at, created_at
FROM bot_preauth_keys
WHERE token = $1
AND used_at IS NULL
AND (expires_at IS NULL OR expires_at > now())
LIMIT 1
`
+3 -80
View File
@@ -16,6 +16,9 @@ UPDATE bots
SET max_context_load_time = 1440,
language = 'auto',
allow_guest = false,
chat_model_id = NULL,
memory_model_id = NULL,
embedding_model_id = NULL,
updated_at = now()
WHERE id = $1
`
@@ -66,35 +69,6 @@ func (q *Queries) GetSettingsByBotID(ctx context.Context, id pgtype.UUID) (GetSe
return i, err
}
const getSettingsByUserID = `-- name: GetSettingsByUserID :one
SELECT id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
FROM users
WHERE id = $1
`
type GetSettingsByUserIDRow struct {
UserID pgtype.UUID `json:"user_id"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
MaxContextLoadTime int32 `json:"max_context_load_time"`
Language string `json:"language"`
}
func (q *Queries) GetSettingsByUserID(ctx context.Context, id pgtype.UUID) (GetSettingsByUserIDRow, error) {
row := q.db.QueryRow(ctx, getSettingsByUserID, id)
var i GetSettingsByUserIDRow
err := row.Scan(
&i.UserID,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
)
return i, err
}
const upsertBotSettings = `-- name: UpsertBotSettings :one
WITH updated AS (
UPDATE bots
@@ -164,54 +138,3 @@ func (q *Queries) UpsertBotSettings(ctx context.Context, arg UpsertBotSettingsPa
)
return i, err
}
const upsertUserSettings = `-- name: UpsertUserSettings :one
UPDATE users
SET chat_model_id = $2,
memory_model_id = $3,
embedding_model_id = $4,
max_context_load_time = $5,
language = $6,
updated_at = now()
WHERE id = $1
RETURNING id AS user_id, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language
`
type UpsertUserSettingsParams struct {
ID pgtype.UUID `json:"id"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
MaxContextLoadTime int32 `json:"max_context_load_time"`
Language string `json:"language"`
}
type UpsertUserSettingsRow struct {
UserID pgtype.UUID `json:"user_id"`
ChatModelID pgtype.Text `json:"chat_model_id"`
MemoryModelID pgtype.Text `json:"memory_model_id"`
EmbeddingModelID pgtype.Text `json:"embedding_model_id"`
MaxContextLoadTime int32 `json:"max_context_load_time"`
Language string `json:"language"`
}
func (q *Queries) UpsertUserSettings(ctx context.Context, arg UpsertUserSettingsParams) (UpsertUserSettingsRow, error) {
row := q.db.QueryRow(ctx, upsertUserSettings,
arg.ID,
arg.ChatModelID,
arg.MemoryModelID,
arg.EmbeddingModelID,
arg.MaxContextLoadTime,
arg.Language,
)
var i UpsertUserSettingsRow
err := row.Scan(
&i.UserID,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
)
return i, err
}
-31
View File
@@ -41,34 +41,3 @@ func (q *Queries) InsertSnapshot(ctx context.Context, arg InsertSnapshotParams)
)
return err
}
const listSnapshotsByContainerID = `-- name: ListSnapshotsByContainerID :many
SELECT id, container_id, parent_snapshot_id, snapshotter, digest, created_at FROM snapshots WHERE container_id = $1 ORDER BY created_at ASC
`
func (q *Queries) ListSnapshotsByContainerID(ctx context.Context, containerID string) ([]Snapshot, error) {
rows, err := q.db.Query(ctx, listSnapshotsByContainerID, containerID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Snapshot
for rows.Next() {
var i Snapshot
if err := rows.Scan(
&i.ID,
&i.ContainerID,
&i.ParentSnapshotID,
&i.Snapshotter,
&i.Digest,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
+11 -135
View File
@@ -37,7 +37,7 @@ SET username = $1,
data_root = $8,
updated_at = now()
WHERE id = $9
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
`
type CreateAccountParams struct {
@@ -75,11 +75,6 @@ func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (U
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -91,7 +86,7 @@ func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (U
const createUser = `-- name: CreateUser :one
INSERT INTO users (is_active, metadata)
VALUES ($1, $2)
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
`
type CreateUserParams struct {
@@ -112,11 +107,6 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -126,7 +116,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e
}
const getAccountByIdentity = `-- name: GetAccountByIdentity :one
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 OR email = $1
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at FROM users WHERE username = $1 OR email = $1
`
func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text) (User, error) {
@@ -142,11 +132,6 @@ func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -156,7 +141,7 @@ func (q *Queries) GetAccountByIdentity(ctx context.Context, identity pgtype.Text
}
const getAccountByUserID = `-- name: GetAccountByUserID :one
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE id = $1
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at FROM users WHERE id = $1
`
func (q *Queries) GetAccountByUserID(ctx context.Context, userID pgtype.UUID) (User, error) {
@@ -172,41 +157,6 @@ func (q *Queries) GetAccountByUserID(ctx context.Context, userID pgtype.UUID) (U
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getAccountByUsername = `-- name: GetAccountByUsername :one
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users WHERE username = $1
`
func (q *Queries) GetAccountByUsername(ctx context.Context, username pgtype.Text) (User, error) {
row := q.db.QueryRow(ctx, getAccountByUsername, username)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.Email,
&i.PasswordHash,
&i.Role,
&i.DisplayName,
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -216,7 +166,7 @@ func (q *Queries) GetAccountByUsername(ctx context.Context, username pgtype.Text
}
const getUserByID = `-- name: GetUserByID :one
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
FROM users
WHERE id = $1
`
@@ -234,11 +184,6 @@ func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error)
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -248,7 +193,7 @@ func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error)
}
const listAccounts = `-- name: ListAccounts :many
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at FROM users
SELECT id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at FROM users
WHERE username IS NOT NULL
ORDER BY created_at DESC
`
@@ -272,11 +217,6 @@ func (q *Queries) ListAccounts(ctx context.Context) ([]User, error) {
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -300,7 +240,7 @@ SET role = $1::user_role,
is_active = $4,
updated_at = now()
WHERE id = $5
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
`
type UpdateAccountAdminParams struct {
@@ -330,11 +270,6 @@ func (q *Queries) UpdateAccountAdmin(ctx context.Context, arg UpdateAccountAdmin
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -348,7 +283,7 @@ UPDATE users
SET last_login_at = now(),
updated_at = now()
WHERE id = $1
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
`
func (q *Queries) UpdateAccountLastLogin(ctx context.Context, id pgtype.UUID) (User, error) {
@@ -364,11 +299,6 @@ func (q *Queries) UpdateAccountLastLogin(ctx context.Context, id pgtype.UUID) (U
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -382,7 +312,7 @@ UPDATE users
SET password_hash = $2,
updated_at = now()
WHERE id = $1
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
`
type UpdateAccountPasswordParams struct {
@@ -403,11 +333,6 @@ func (q *Queries) UpdateAccountPassword(ctx context.Context, arg UpdateAccountPa
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -423,7 +348,7 @@ SET display_name = $2,
is_active = $4,
updated_at = now()
WHERE id = $1
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
`
type UpdateAccountProfileParams struct {
@@ -451,50 +376,6 @@ func (q *Queries) UpdateAccountProfile(ctx context.Context, arg UpdateAccountPro
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateUserStatus = `-- name: UpdateUserStatus :one
UPDATE users
SET is_active = $2,
updated_at = now()
WHERE id = $1
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
`
type UpdateUserStatusParams struct {
ID pgtype.UUID `json:"id"`
IsActive bool `json:"is_active"`
}
func (q *Queries) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error) {
row := q.db.QueryRow(ctx, updateUserStatus, arg.ID, arg.IsActive)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.Email,
&i.PasswordHash,
&i.Role,
&i.DisplayName,
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
@@ -526,7 +407,7 @@ ON CONFLICT (username) DO UPDATE SET
is_active = EXCLUDED.is_active,
data_root = EXCLUDED.data_root,
updated_at = now()
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, chat_model_id, memory_model_id, embedding_model_id, max_context_load_time, language, is_active, metadata, created_at, updated_at
RETURNING id, username, email, password_hash, role, display_name, avatar_url, data_root, last_login_at, is_active, metadata, created_at, updated_at
`
type UpsertAccountByUsernameParams struct {
@@ -564,11 +445,6 @@ func (q *Queries) UpsertAccountByUsername(ctx context.Context, arg UpsertAccount
&i.AvatarUrl,
&i.DataRoot,
&i.LastLoginAt,
&i.ChatModelID,
&i.MemoryModelID,
&i.EmbeddingModelID,
&i.MaxContextLoadTime,
&i.Language,
&i.IsActive,
&i.Metadata,
&i.CreatedAt,
+7 -54
View File
@@ -3,16 +3,13 @@ package embeddings
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/models"
)
@@ -27,12 +24,11 @@ const (
)
type Request struct {
Type string
Provider string
Model string
Dimensions int
Input Input
ChannelIdentityID string
Type string
Provider string
Model string
Dimensions int
Input Input
}
type Input struct {
@@ -44,7 +40,7 @@ type Input struct {
type Usage struct {
InputTokens int
ImageTokens int
VideoTokens int
Duration int
}
type Result struct {
@@ -165,7 +161,7 @@ func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) {
Usage: Usage{
InputTokens: usage.InputTokens,
ImageTokens: usage.ImageTokens,
VideoTokens: usage.Duration,
Duration: usage.Duration,
},
}, nil
}
@@ -180,30 +176,6 @@ func (r *Resolver) selectEmbeddingModel(ctx context.Context, req Request) (model
return models.GetResponse{}, errors.New("models service not configured")
}
// If no model specified and no provider specified, try to get per-user embedding model.
if req.Model == "" && req.Provider == "" && strings.TrimSpace(req.ChannelIdentityID) != "" {
modelID, err := r.loadChannelIdentityEmbeddingModelID(ctx, req.ChannelIdentityID)
if err != nil {
return models.GetResponse{}, err
}
if modelID != "" {
selected, err := r.modelsService.GetByModelID(ctx, modelID)
if err != nil {
return models.GetResponse{}, fmt.Errorf("settings embedding model not found: %w", err)
}
if selected.Type != models.ModelTypeEmbedding {
return models.GetResponse{}, errors.New("settings embedding model is not an embedding model")
}
if req.Type == TypeMultimodal && !selected.IsMultimodal {
return models.GetResponse{}, errors.New("settings embedding model does not support multimodal")
}
if req.Type == TypeText && selected.IsMultimodal {
return models.GetResponse{}, errors.New("settings embedding model does not support text embeddings")
}
return selected, nil
}
}
var candidates []models.GetResponse
var err error
if req.Provider != "" {
@@ -257,22 +229,3 @@ func (r *Resolver) fetchProvider(ctx context.Context, providerID string) (sqlc.L
copy(pgID.Bytes[:], parsed[:])
return r.queries.GetLlmProviderByID(ctx, pgID)
}
func (r *Resolver) loadChannelIdentityEmbeddingModelID(ctx context.Context, channelIdentityID string) (string, error) {
if r.queries == nil {
return "", nil
}
pgChannelIdentityID, err := db.ParseUUID(channelIdentityID)
if err != nil {
return "", err
}
row, err := r.queries.GetSettingsByUserID(ctx, pgChannelIdentityID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", nil
}
return "", err
}
return strings.TrimSpace(row.EmbeddingModelID.String), nil
}
+1 -10
View File
@@ -10,9 +10,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bind"
"github.com/memohai/memoh/internal/identity"
)
// BindHandler manages channel identity bind code issuance via REST API.
@@ -80,12 +78,5 @@ func (h *BindHandler) Issue(c echo.Context) error {
}
func (h *BindHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return userID, nil
return RequireChannelIdentityID(c)
}
+1 -10
View File
@@ -7,9 +7,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/identity"
)
type ChannelHandler struct {
@@ -161,12 +159,5 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error {
}
func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
+43 -50
View File
@@ -18,18 +18,17 @@ import (
"github.com/containerd/containerd/v2/pkg/oci"
"github.com/containerd/errdefs"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/labstack/echo/v4"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/config"
ctr "github.com/memohai/memoh/internal/containerd"
"github.com/memohai/memoh/internal/db"
dbsqlc "github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/mcp"
"github.com/memohai/memoh/internal/policy"
)
@@ -166,7 +165,10 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
if dataRoot == "" {
dataRoot = config.DefaultDataRoot
}
dataRoot, _ = filepath.Abs(dataRoot)
dataRoot, err = filepath.Abs(dataRoot)
if err != nil {
h.logger.Warn("filepath.Abs failed", slog.Any("error", err))
}
dataMount := strings.TrimSpace(h.cfg.DataMount)
if dataMount == "" {
dataMount = config.DefaultDataMount
@@ -253,7 +255,9 @@ func (h *ContainerdHandler) CreateContainer(c echo.Context) error {
}
}
} else {
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true})
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}); err != nil {
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
}
h.logger.Error("mcp container network setup failed",
slog.String("container_id", containerID),
slog.Any("error", netErr),
@@ -303,7 +307,9 @@ func (h *ContainerdHandler) ensureContainerAndTask(ctx context.Context, containe
if tasks[0].Status == tasktypes.Status_RUNNING {
return nil
}
_ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true})
if err := h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}); err != nil {
h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err))
}
}
task, err := h.service.StartTask(ctx, containerID, &ctr.StartTaskOptions{
@@ -313,7 +319,9 @@ func (h *ContainerdHandler) ensureContainerAndTask(ctx context.Context, containe
return err
}
if err := ctr.SetupNetwork(ctx, task, containerID); err != nil {
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true})
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}); err != nil {
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
}
return err
}
return nil
@@ -324,10 +332,14 @@ func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (s
if h.queries != nil {
pgBotID, err := db.ParseUUID(botID)
if err == nil {
row, err := h.queries.GetContainerByBotID(ctx, pgBotID)
if err == nil && strings.TrimSpace(row.ContainerID) != "" {
row, dbErr := h.queries.GetContainerByBotID(ctx, pgBotID)
if dbErr == nil && strings.TrimSpace(row.ContainerID) != "" {
return row.ContainerID, nil
}
if dbErr != nil && !errors.Is(dbErr, pgx.ErrNoRows) {
h.logger.Warn("botContainerID: db lookup failed",
slog.String("bot_id", botID), slog.Any("error", dbErr))
}
}
}
// Fallback: search by containerd label
@@ -510,7 +522,9 @@ func (h *ContainerdHandler) StopContainer(c echo.Context) error {
}); err != nil && !errdefs.IsNotFound(err) {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
_ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true})
if err := h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}); err != nil {
h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err))
}
if h.queries != nil {
if pgBotID, parseErr := db.ParseUUID(botID); parseErr == nil {
if dbErr := h.queries.UpdateContainerStopped(ctx, pgBotID); dbErr != nil {
@@ -646,44 +660,11 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) {
}
func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
func (h *ContainerdHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) && h.policyService != nil {
allowGuest, policyErr := h.policyService.AllowGuest(ctx, botID)
if policyErr == nil && allowGuest {
bot, getErr := h.botService.Get(ctx, botID)
if getErr == nil {
return bot, nil
}
}
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
// SetupBotContainer creates and starts the MCP container for a bot.
@@ -701,7 +682,11 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string)
if dataRoot == "" {
dataRoot = config.DefaultDataRoot
}
dataRoot, _ = filepath.Abs(dataRoot)
if absRoot, absErr := filepath.Abs(dataRoot); absErr != nil {
h.logger.Warn("filepath.Abs failed", slog.Any("error", absErr))
} else {
dataRoot = absRoot
}
dataMount := strings.TrimSpace(h.cfg.DataMount)
if dataMount == "" {
dataMount = config.DefaultDataMount
@@ -786,7 +771,9 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string)
}
}
} else {
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true})
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{Force: true}); err != nil {
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
}
h.logger.Error("setup bot container: network setup failed",
slog.String("bot_id", botID),
slog.String("container_id", containerID),
@@ -830,15 +817,21 @@ func (h *ContainerdHandler) CleanupBotContainer(ctx context.Context, botID strin
if task, taskErr := h.service.GetTask(ctx, containerID); taskErr == nil {
h.logger.Info("CleanupBotContainer: removing network", slog.String("container_id", containerID))
_ = ctr.RemoveNetwork(ctx, task, containerID)
if err := ctr.RemoveNetwork(ctx, task, containerID); err != nil {
h.logger.Warn("cleanup: remove network failed", slog.String("container_id", containerID), slog.Any("error", err))
}
}
h.logger.Info("CleanupBotContainer: stopping task", slog.String("container_id", containerID))
_ = h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{
if err := h.service.StopTask(ctx, containerID, &ctr.StopTaskOptions{
Timeout: 5 * time.Second,
Force: true,
})
}); err != nil {
h.logger.Warn("cleanup: stop task failed", slog.String("container_id", containerID), slog.Any("error", err))
}
h.logger.Info("CleanupBotContainer: deleting task", slog.String("container_id", containerID))
_ = h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true})
if err := h.service.DeleteTask(ctx, containerID, &ctr.DeleteTaskOptions{Force: true}); err != nil {
h.logger.Warn("cleanup: delete task failed", slog.String("container_id", containerID), slog.Any("error", err))
}
h.logger.Info("CleanupBotContainer: deleting container", slog.String("container_id", containerID))
if err := h.service.DeleteContainer(ctx, containerID, &ctr.DeleteContainerOptions{
+2 -10
View File
@@ -8,7 +8,6 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/db/sqlc"
"github.com/memohai/memoh/internal/embeddings"
"github.com/memohai/memoh/internal/models"
@@ -48,7 +47,7 @@ type EmbeddingsResponse struct {
type EmbeddingsUsage struct {
InputTokens int `json:"input_tokens,omitempty"`
ImageTokens int `json:"image_tokens,omitempty"`
VideoTokens int `json:"video_tokens,omitempty"`
Duration int `json:"duration,omitempty"`
}
func NewEmbeddingsHandler(log *slog.Logger, modelsService *models.Service, queries *sqlc.Queries) *EmbeddingsHandler {
@@ -85,12 +84,6 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
req.Input.ImageURL = strings.TrimSpace(req.Input.ImageURL)
req.Input.VideoURL = strings.TrimSpace(req.Input.VideoURL)
userID := ""
if c.Get("user") != nil {
if value, err := auth.UserIDFromContext(c); err == nil {
userID = value
}
}
result, err := h.resolver.Embed(c.Request().Context(), embeddings.Request{
Type: req.Type,
Provider: req.Provider,
@@ -101,7 +94,6 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
ImageURL: req.Input.ImageURL,
VideoURL: req.Input.VideoURL,
},
ChannelIdentityID: userID,
})
if err != nil {
message := err.Error()
@@ -137,7 +129,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error {
Usage: EmbeddingsUsage{
InputTokens: result.Usage.InputTokens,
ImageTokens: result.Usage.ImageTokens,
VideoTokens: result.Usage.VideoTokens,
Duration: result.Usage.Duration,
},
})
}
-16
View File
@@ -53,22 +53,6 @@ func (h *ContainerdHandler) validateMCPContainer(ctx context.Context, containerI
return nil
}
func (h *ContainerdHandler) callMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) (map[string]any, error) {
session, err := h.getMCPSession(ctx, containerID)
if err != nil {
return nil, err
}
return session.call(ctx, req)
}
func (h *ContainerdHandler) notifyMCPServer(ctx context.Context, containerID string, req mcptools.JSONRPCRequest) error {
session, err := h.getMCPSession(ctx, containerID)
if err != nil {
return err
}
return session.notify(ctx, req)
}
type mcpSession struct {
stdin io.WriteCloser
stdout io.ReadCloser
+48
View File
@@ -0,0 +1,48 @@
package handlers
import (
"context"
"errors"
"net/http"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/identity"
)
// RequireChannelIdentityID extracts and validates the channel identity ID from the request context.
func RequireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
}
// AuthorizeBotAccess validates that the given identity has access to the specified bot.
func AuthorizeBotAccess(ctx context.Context, botService *bots.Service, accountService *accounts.Service, channelIdentityID, botID string, policy bots.AccessPolicy) (bots.Bot, error) {
if botService == nil || accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, policy)
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
}
+5 -30
View File
@@ -4,7 +4,6 @@ import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
@@ -13,12 +12,10 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/channel"
"github.com/memohai/memoh/internal/channel/adapters/local"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/identity"
)
// LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history.
@@ -103,7 +100,9 @@ func (h *LocalChannelHandler) StreamMessages(c echo.Context) error {
if err != nil {
continue
}
_, _ = writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data)))
if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))); err != nil {
return nil // client disconnected
}
writer.Flush()
flusher.Flush()
}
@@ -185,33 +184,9 @@ func (h *LocalChannelHandler) ensureBotParticipant(ctx context.Context, botID, c
}
func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
func (h *LocalChannelHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: true})
}
+2 -28
View File
@@ -11,9 +11,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/mcp"
)
@@ -218,33 +216,9 @@ func (h *MCPHandler) Delete(c echo.Context) error {
}
func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return userID, nil
return RequireChannelIdentityID(c)
}
func (h *MCPHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
+2 -11
View File
@@ -10,9 +10,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/memory"
)
@@ -168,7 +166,7 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error {
for _, scope := range scopes {
filters := buildNamespaceFilters(scope.Namespace, scope.ScopeID, payload.Filters)
if botID != "" {
filters["botId"] = botID
filters["bot_id"] = botID
}
req := memory.SearchRequest{
Query: payload.Query,
@@ -378,12 +376,5 @@ func (h *MemoryHandler) requireChatParticipant(ctx context.Context, chatID, chan
}
func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
+15 -54
View File
@@ -15,12 +15,10 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/channel/identities"
"github.com/memohai/memoh/internal/conversation"
"github.com/memohai/memoh/internal/conversation/flow"
"github.com/memohai/memoh/internal/identity"
messagepkg "github.com/memohai/memoh/internal/message"
messageevent "github.com/memohai/memoh/internal/message/event"
)
@@ -85,7 +83,7 @@ func (h *MessageHandler) SendMessage(c echo.Context) error {
return err
}
var req flow.ChatRequest
var req conversation.ChatRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
@@ -132,7 +130,7 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error {
return err
}
var req flow.ChatRequest
var req conversation.ChatRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
@@ -200,10 +198,12 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error {
h.logger.Error("conversation stream failed", slog.Any("error", err))
if processingState == "started" {
processingState = "failed"
_ = writeSSEJSON(writer, flusher, map[string]string{
"type": "processing_failed",
"error": err.Error(),
})
if writeErr := writeSSEJSON(writer, flusher, map[string]string{
"type": "processing_failed",
"error": err.Error(),
}); writeErr != nil {
h.logger.Warn("write SSE processing_failed event failed", slog.Any("error", writeErr))
}
}
errData := map[string]string{
"type": "error",
@@ -459,7 +459,7 @@ func (h *MessageHandler) DeleteMessages(c echo.Context) error {
// resolveWebChannelIdentity resolves (web, user_id) to a channel identity and sets req.SourceChannelIdentityID.
// Web uses user_id as the channel subject id (like Feishu open_id); the resolved ci has display_name and is linked to the user.
// Returns the channel_identity_id to use for the rest of the flow, or the original userID if resolution is skipped/fails.
func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID string, req *flow.ChatRequest) string {
func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID string, req *conversation.ChatRequest) string {
if strings.TrimSpace(req.CurrentChannel) != "web" || h.channelIdentitySvc == nil || strings.TrimSpace(userID) == "" {
return userID
}
@@ -476,62 +476,23 @@ func (h *MessageHandler) resolveWebChannelIdentity(ctx context.Context, userID s
if err != nil {
return userID
}
_ = h.channelIdentitySvc.LinkChannelIdentityToUser(ctx, ci.ID, userID)
if err := h.channelIdentitySvc.LinkChannelIdentityToUser(ctx, ci.ID, userID); err != nil {
h.logger.Warn("link channel identity to user failed", slog.Any("error", err))
}
req.SourceChannelIdentityID = ci.ID
return ci.ID
}
func (h *MessageHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
func (h *MessageHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: true})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: true})
}
func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot management access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
func (h *MessageHandler) requireParticipant(ctx context.Context, conversationID, channelIdentityID string) error {
+5 -72
View File
@@ -4,26 +4,21 @@ import (
"log/slog"
"net/http"
"net/url"
"strings"
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/models"
"github.com/memohai/memoh/internal/settings"
)
type ModelsHandler struct {
service *models.Service
settingsService *settings.Service
logger *slog.Logger
service *models.Service
logger *slog.Logger
}
func NewModelsHandler(log *slog.Logger, service *models.Service, settingsService *settings.Service) *ModelsHandler {
func NewModelsHandler(log *slog.Logger, service *models.Service) *ModelsHandler {
return &ModelsHandler{
service: service,
settingsService: settingsService,
logger: log.With(slog.String("handler", "models")),
service: service,
logger: log.With(slog.String("handler", "models")),
}
}
@@ -33,7 +28,6 @@ func (h *ModelsHandler) Register(e *echo.Echo) {
group.GET("", h.List)
group.GET("/:id", h.GetByID)
group.GET("/model/:modelId", h.GetByModelID)
group.POST("/enable", h.Enable)
group.PUT("/:id", h.UpdateByID)
group.PUT("/model/:modelId", h.UpdateByModelID)
group.DELETE("/:id", h.DeleteByID)
@@ -145,67 +139,6 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error {
return c.JSON(http.StatusOK, resp)
}
type EnableModelRequest struct {
As string `json:"as"`
ModelID string `json:"model_id"`
}
// Enable godoc
// @Summary Enable model for chat/memory/embedding
// @Description Update the current user's settings to use the selected model
// @Tags models
// @Param payload body handlers.EnableModelRequest true "Enable model payload"
// @Success 200 {object} settings.Settings
// @Failure 400 {object} ErrorResponse
// @Failure 404 {object} ErrorResponse
// @Failure 500 {object} ErrorResponse
// @Router /models/enable [post]
func (h *ModelsHandler) Enable(c echo.Context) error {
if h.settingsService == nil {
return echo.NewHTTPError(http.StatusInternalServerError, "settings service not configured")
}
userID, err := auth.UserIDFromContext(c)
if err != nil {
return err
}
var req EnableModelRequest
if err := c.Bind(&req); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
req.As = strings.ToLower(strings.TrimSpace(req.As))
req.ModelID = strings.TrimSpace(req.ModelID)
if req.As == "" || req.ModelID == "" {
return echo.NewHTTPError(http.StatusBadRequest, "as and model_id are required")
}
if req.As != "chat" && req.As != "memory" && req.As != "embedding" {
return echo.NewHTTPError(http.StatusBadRequest, "as must be one of chat, memory, embedding")
}
model, err := h.service.GetByModelID(c.Request().Context(), req.ModelID)
if err != nil {
return echo.NewHTTPError(http.StatusNotFound, err.Error())
}
if req.As == "embedding" && model.Type != models.ModelTypeEmbedding {
return echo.NewHTTPError(http.StatusBadRequest, "model is not an embedding model")
}
if (req.As == "chat" || req.As == "memory") && model.Type != models.ModelTypeChat {
return echo.NewHTTPError(http.StatusBadRequest, "model is not a chat model")
}
upsert := settings.UpsertRequest{}
switch req.As {
case "chat":
upsert.ChatModelID = req.ModelID
case "memory":
upsert.MemoryModelID = req.ModelID
case "embedding":
upsert.EmbeddingModelID = req.ModelID
}
resp, err := h.settingsService.Upsert(c.Request().Context(), userID, upsert)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return c.JSON(http.StatusOK, resp)
}
// UpdateByID godoc
// @Summary Update model by internal ID
// @Description Update a model configuration by its internal UUID
+2 -29
View File
@@ -2,7 +2,6 @@ package handlers
import (
"context"
"errors"
"net/http"
"strings"
"time"
@@ -10,9 +9,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/preauth"
)
@@ -67,33 +64,9 @@ func (h *PreauthHandler) Issue(c echo.Context) error {
}
func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return userID, nil
return RequireChannelIdentityID(c)
}
func (h *PreauthHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
+2 -29
View File
@@ -2,7 +2,6 @@ package handlers
import (
"context"
"errors"
"log/slog"
"net/http"
"strings"
@@ -10,9 +9,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/schedule"
)
@@ -219,33 +216,9 @@ func (h *ScheduleHandler) Delete(c echo.Context) error {
}
func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) {
userID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(userID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return userID, nil
return RequireChannelIdentityID(c)
}
func (h *ScheduleHandler) authorizeBotAccess(ctx context.Context, userID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, userID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, userID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, userID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
+2 -28
View File
@@ -10,9 +10,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/settings"
)
@@ -130,33 +128,9 @@ func (h *SettingsHandler) Delete(c echo.Context) error {
}
func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
func (h *SettingsHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
+2 -29
View File
@@ -2,7 +2,6 @@ package handlers
import (
"context"
"errors"
"log/slog"
"net/http"
"strings"
@@ -10,9 +9,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/memohai/memoh/internal/accounts"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/bots"
"github.com/memohai/memoh/internal/identity"
"github.com/memohai/memoh/internal/subagent"
)
@@ -433,33 +430,9 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error {
}
func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
func (h *SubagentHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
if h.botService == nil || h.accountService == nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, "bot services not configured")
}
isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
+3 -24
View File
@@ -32,7 +32,7 @@ type UsersHandler struct {
}
type listMyIdentitiesResponse struct {
UserID string `json:"user_id"`
UserID string `json:"user_id"`
Items []identities.ChannelIdentity `json:"items"`
}
@@ -943,30 +943,9 @@ func (h *UsersHandler) SendBotMessageSession(c echo.Context) error {
}
func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) {
isAdmin, err := h.service.IsAdmin(ctx, channelIdentityID)
if err != nil {
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
bot, err := h.botService.AuthorizeAccess(ctx, channelIdentityID, botID, isAdmin, bots.AccessPolicy{AllowPublicMember: false})
if err != nil {
if errors.Is(err, bots.ErrBotNotFound) {
return bots.Bot{}, echo.NewHTTPError(http.StatusNotFound, "bot not found")
}
if errors.Is(err, bots.ErrBotAccessDenied) {
return bots.Bot{}, echo.NewHTTPError(http.StatusForbidden, "bot access denied")
}
return bots.Bot{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
return bot, nil
return AuthorizeBotAccess(ctx, h.botService, h.service, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false})
}
func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) {
channelIdentityID, err := auth.UserIDFromContext(c)
if err != nil {
return "", err
}
if err := identity.ValidateChannelIdentityID(channelIdentityID); err != nil {
return "", echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
return channelIdentityID, nil
return RequireChannelIdentityID(c)
}
+2 -2
View File
@@ -19,7 +19,7 @@ type Connection struct {
Name string `json:"name"`
Type string `json:"type"`
Config map[string]any `json:"config"`
Active bool `json:"active"`
Active bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -29,7 +29,7 @@ type UpsertRequest struct {
Name string `json:"name"`
Type string `json:"type,omitempty"`
Config map[string]any `json:"config"`
Active *bool `json:"active,omitempty"`
Active *bool `json:"is_active,omitempty"`
}
// ListResponse wraps MCP connection list responses.
+9 -11
View File
@@ -175,7 +175,9 @@ func (m *Manager) Start(ctx context.Context, botID string) error {
return err
}
if err := ctr.SetupNetwork(ctx, task, m.containerID(botID)); err != nil {
_ = m.service.StopTask(ctx, m.containerID(botID), &ctr.StopTaskOptions{Force: true})
if stopErr := m.service.StopTask(ctx, m.containerID(botID), &ctr.StopTaskOptions{Force: true}); stopErr != nil {
m.logger.Warn("cleanup: stop task failed", slog.String("container_id", m.containerID(botID)), slog.Any("error", stopErr))
}
return err
}
return nil
@@ -197,9 +199,13 @@ func (m *Manager) Delete(ctx context.Context, botID string) error {
}
if task, taskErr := m.service.GetTask(ctx, m.containerID(botID)); taskErr == nil {
_ = ctr.RemoveNetwork(ctx, task, m.containerID(botID))
if err := ctr.RemoveNetwork(ctx, task, m.containerID(botID)); err != nil {
m.logger.Warn("cleanup: remove network failed", slog.String("container_id", m.containerID(botID)), slog.Any("error", err))
}
}
if err := m.service.DeleteTask(ctx, m.containerID(botID), &ctr.DeleteTaskOptions{Force: true}); err != nil {
m.logger.Warn("cleanup: delete task failed", slog.String("container_id", m.containerID(botID)), slog.Any("error", err))
}
_ = m.service.DeleteTask(ctx, m.containerID(botID), &ctr.DeleteTaskOptions{Force: true})
return m.service.DeleteContainer(ctx, m.containerID(botID), &ctr.DeleteContainerOptions{
CleanupSnapshot: true,
})
@@ -347,14 +353,6 @@ func (m *Manager) execWithCaptureContainerd(ctx context.Context, req ExecRequest
}, nil
}
// sshShellQuote wraps a string in single quotes for safe SSH transport.
func sshShellQuote(s string) string {
if s == "" {
return "''"
}
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
}
// DataDir returns the host data directory for a bot.
func (m *Manager) DataDir(botID string) (string, error) {
if err := validateBotID(botID); err != nil {
+1 -1
View File
@@ -134,7 +134,7 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex
Filters: map[string]any{
"namespace": sharedMemoryNamespace,
"scopeId": botID,
"botId": botID,
"bot_id": botID,
},
})
if err != nil {
@@ -23,15 +23,15 @@ func (f *fakeSearcher) Search(ctx context.Context, req memory.SearchRequest) (me
}
type fakeChatAccessor struct {
chat conversation.Chat
chat conversation.Conversation
getErr error
participant bool
participantErr error
}
func (f *fakeChatAccessor) Get(ctx context.Context, conversationID string) (conversation.Chat, error) {
func (f *fakeChatAccessor) Get(ctx context.Context, conversationID string) (conversation.Conversation, error) {
if f.getErr != nil {
return conversation.Chat{}, f.getErr
return conversation.Conversation{}, f.getErr
}
return f.chat, nil
}
@@ -43,8 +43,8 @@ func (f *fakeChatAccessor) IsParticipant(ctx context.Context, conversationID, ch
return f.participant, nil
}
func (f *fakeChatAccessor) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (conversation.ChatReadAccess, error) {
return conversation.ChatReadAccess{}, nil
func (f *fakeChatAccessor) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (conversation.ConversationReadAccess, error) {
return conversation.ConversationReadAccess{}, nil
}
type fakeAdminChecker struct {
@@ -180,7 +180,7 @@ func TestExecutor_CallTool_ChatNotFound(t *testing.T) {
func TestExecutor_CallTool_BotMismatch(t *testing.T) {
accessor := &fakeChatAccessor{
chat: conversation.Chat{BotID: "other-bot", ID: "c1"},
chat: conversation.Conversation{BotID: "other-bot", ID: "c1"},
}
searcher := &fakeSearcher{}
exec := NewExecutor(nil, searcher, accessor, nil)
@@ -196,7 +196,7 @@ func TestExecutor_CallTool_BotMismatch(t *testing.T) {
func TestExecutor_CallTool_NotParticipant(t *testing.T) {
accessor := &fakeChatAccessor{
chat: conversation.Chat{BotID: "bot1", ID: "c1"},
chat: conversation.Conversation{BotID: "bot1", ID: "c1"},
participant: false,
}
searcher := &fakeSearcher{}
@@ -216,7 +216,7 @@ func TestExecutor_CallTool_AdminBypass(t *testing.T) {
resp: memory.SearchResponse{Results: []memory.MemoryItem{{ID: "id1", Memory: "m1", Score: 0.8}}},
}
accessor := &fakeChatAccessor{
chat: conversation.Chat{BotID: "bot1", ID: "c1"},
chat: conversation.Conversation{BotID: "bot1", ID: "c1"},
participant: false,
}
admin := &fakeAdminChecker{admin: true}
+4 -1
View File
@@ -91,7 +91,10 @@ func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext,
route, ok := s.getRoute(botID, toolName)
if !ok {
_, _ = s.ListTools(ctx, session)
// Refresh route cache; result intentionally discarded.
if _, err := s.ListTools(ctx, session); err != nil {
s.logger.Warn("federation: refresh tools cache failed", slog.Any("error", err))
}
route, ok = s.getRoute(botID, toolName)
if !ok {
return nil, mcpgw.ErrToolNotFound
+2 -1
View File
@@ -2,6 +2,7 @@ package mcp
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
@@ -106,7 +107,7 @@ func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionCo
}
result, err := executor.CallTool(ctx, session, toolName, arguments)
if err != nil {
if err == ErrToolNotFound {
if errors.Is(err, ErrToolNotFound) {
return BuildToolErrorResult("tool not found: " + toolName), nil
}
return BuildToolErrorResult(err.Error()), nil
+3 -1
View File
@@ -349,7 +349,9 @@ func parseJSONMap(data []byte) map[string]any {
return nil
}
var m map[string]any
_ = json.Unmarshal(data, &m)
if err := json.Unmarshal(data, &m); err != nil {
slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err))
}
return m
}
+1 -1
View File
@@ -314,7 +314,7 @@ func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64,
func convertToGetResponse(dbModel sqlc.Model) GetResponse {
resp := GetResponse{
ModelId: dbModel.ModelID,
ModelID: dbModel.ModelID,
Model: Model{
ModelID: dbModel.ModelID,
IsMultimodal: dbModel.IsMultimodal,
+1 -72
View File
@@ -83,7 +83,7 @@ func ExampleService_UpdateByModelID() {
// if err != nil {
// // handle error
// }
// fmt.Printf("Updated model: %s\n", resp.ModelId)
// fmt.Printf("Updated model: %s\n", resp.ModelID)
}
func ExampleService_DeleteByModelID() {
@@ -202,74 +202,3 @@ func TestModelTypes(t *testing.T) {
assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope)
})
}
// Integration test example (requires actual database)
// func TestService_Integration(t *testing.T) {
// if testing.Short() {
// t.Skip("Skipping integration test")
// }
//
// ctx := context.Background()
//
// // Setup database connection
// pool, err := db.Open(ctx, config.PostgresConfig{
// Host: "localhost",
// Port: 5432,
// User: "test",
// Password: "test",
// Database: "test_db",
// SSLMode: "disable",
// })
// require.NoError(t, err)
// defer pool.Close()
//
// queries := sqlc.New(pool)
// service := models.NewService(queries)
//
// // Test Create
// createReq := models.AddRequest{
// ModelID: "test-gpt-4",
// Name: "Test GPT-4",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-test",
// ClientType: models.ClientTypeOpenAI,
// Type: models.ModelTypeChat,
// }
// createResp, err := service.Create(ctx, createReq)
// require.NoError(t, err)
// assert.NotEmpty(t, createResp.ID)
// assert.Equal(t, "test-gpt-4", createResp.ModelID)
//
// // Test GetByModelID
// getResp, err := service.GetByModelID(ctx, "test-gpt-4")
// require.NoError(t, err)
// assert.Equal(t, "test-gpt-4", getResp.ModelID)
// assert.Equal(t, "Test GPT-4", getResp.Name)
//
// // Test List
// models, err := service.List(ctx)
// require.NoError(t, err)
// assert.NotEmpty(t, models)
//
// // Test Update
// updateReq := models.UpdateRequest{
// ModelID: "test-gpt-4",
// Name: "Updated GPT-4",
// BaseURL: "https://api.openai.com/v1",
// APIKey: "sk-test-updated",
// ClientType: models.ClientTypeOpenAI,
// Type: models.ModelTypeChat,
// }
// updateResp, err := service.UpdateByModelID(ctx, "test-gpt-4", updateReq)
// require.NoError(t, err)
// assert.Equal(t, "Updated GPT-4", updateResp.Name)
//
// // Test Count
// count, err := service.Count(ctx)
// require.NoError(t, err)
// assert.Greater(t, count, int64(0))
//
// // Test Delete
// err = service.DeleteByModelID(ctx, "test-gpt-4")
// require.NoError(t, err)
// }
+1 -1
View File
@@ -75,7 +75,7 @@ type GetRequest struct {
}
type GetResponse struct {
ModelId string `json:"model_id"`
ModelID string `json:"model_id"`
Model
}
+7 -7
View File
@@ -90,13 +90,13 @@ func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) {
func normalizeKey(row sqlc.BotPreauthKey) Key {
return Key{
ID: row.ID.String(),
BotID: row.BotID.String(),
Token: strings.TrimSpace(row.Token),
IssuedByChannelIdentityID: row.IssuedByUserID.String(),
ExpiresAt: timeFromPg(row.ExpiresAt),
UsedAt: timeFromPg(row.UsedAt),
CreatedAt: timeFromPg(row.CreatedAt),
ID: row.ID.String(),
BotID: row.BotID.String(),
Token: strings.TrimSpace(row.Token),
IssuedByUserID: row.IssuedByUserID.String(),
ExpiresAt: timeFromPg(row.ExpiresAt),
UsedAt: timeFromPg(row.UsedAt),
CreatedAt: timeFromPg(row.CreatedAt),
}
}
+7 -7
View File
@@ -4,11 +4,11 @@ import "time"
// Key represents a bot pre-authorization key.
type Key struct {
ID string
BotID string
Token string
IssuedByChannelIdentityID string
ExpiresAt time.Time
UsedAt time.Time
CreatedAt time.Time
ID string
BotID string
Token string
IssuedByUserID string
ExpiresAt time.Time
UsedAt time.Time
CreatedAt time.Time
}
+3 -1
View File
@@ -211,7 +211,9 @@ func (s *Service) CountByClientType(ctx context.Context, clientType ClientType)
func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse {
var metadata map[string]any
if len(provider.Metadata) > 0 {
_ = json.Unmarshal(provider.Metadata, &metadata)
if err := json.Unmarshal(provider.Metadata, &metadata); err != nil {
slog.Warn("provider metadata unmarshal failed", slog.String("id", provider.ID.String()), slog.Any("error", err))
}
}
// Mask API key (show only first 8 characters)
+10 -5
View File
@@ -187,7 +187,9 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sch
if err != nil {
return Schedule{}, err
}
s.rescheduleJob(updated)
if err := s.rescheduleJob(updated); err != nil {
return Schedule{}, fmt.Errorf("reschedule job: %w", err)
}
return toSchedule(updated), nil
}
@@ -287,7 +289,9 @@ func (s *Service) scheduleJob(schedule sqlc.Schedule) error {
return fmt.Errorf("schedule id missing")
}
job := func() {
_ = s.runSchedule(context.Background(), toSchedule(schedule))
if err := s.runSchedule(context.Background(), toSchedule(schedule)); err != nil {
s.logger.Error("scheduled job failed", slog.String("schedule_id", schedule.ID.String()), slog.Any("error", err))
}
}
entryID, err := s.cron.AddFunc(schedule.Pattern, job)
if err != nil {
@@ -299,15 +303,16 @@ func (s *Service) scheduleJob(schedule sqlc.Schedule) error {
return nil
}
func (s *Service) rescheduleJob(schedule sqlc.Schedule) {
func (s *Service) rescheduleJob(schedule sqlc.Schedule) error {
id := schedule.ID.String()
if id == "" {
return
return nil
}
s.removeJob(id)
if schedule.Enabled {
_ = s.scheduleJob(schedule)
return s.scheduleJob(schedule)
}
return nil
}
func (s *Service) removeJob(id string) {
-1
View File
@@ -11,7 +11,6 @@ type TriggerPayload struct {
MaxCalls *int
Command string
OwnerUserID string
ChatID string
}
// Triggerer triggers schedule execution for chat-related jobs.
-96
View File
@@ -7,7 +7,6 @@ import (
"log/slog"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/memohai/memoh/internal/db"
@@ -28,83 +27,6 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service {
}
}
// Get returns user-level settings.
func (s *Service) Get(ctx context.Context, userID string) (Settings, error) {
pgID, err := db.ParseUUID(userID)
if err != nil {
return Settings{}, err
}
row, err := s.queries.GetSettingsByUserID(ctx, pgID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return Settings{
ChatModelID: "",
MemoryModelID: "",
EmbeddingModelID: "",
MaxContextLoadTime: DefaultMaxContextLoadTime,
Language: DefaultLanguage,
}, nil
}
return Settings{}, err
}
return normalizeUserSetting(row), nil
}
// Upsert creates or updates user-level settings.
func (s *Service) Upsert(ctx context.Context, userID string, req UpsertRequest) (Settings, error) {
if s.queries == nil {
return Settings{}, fmt.Errorf("settings queries not configured")
}
pgID, err := db.ParseUUID(userID)
if err != nil {
return Settings{}, err
}
current := Settings{
ChatModelID: "",
MemoryModelID: "",
EmbeddingModelID: "",
MaxContextLoadTime: DefaultMaxContextLoadTime,
Language: DefaultLanguage,
}
existing, err := s.queries.GetSettingsByUserID(ctx, pgID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return Settings{}, err
}
if err == nil {
current = normalizeUserSetting(existing)
}
if value := strings.TrimSpace(req.ChatModelID); value != "" {
current.ChatModelID = value
}
if value := strings.TrimSpace(req.MemoryModelID); value != "" {
current.MemoryModelID = value
}
if value := strings.TrimSpace(req.EmbeddingModelID); value != "" {
current.EmbeddingModelID = value
}
if req.MaxContextLoadTime != nil && *req.MaxContextLoadTime > 0 {
current.MaxContextLoadTime = *req.MaxContextLoadTime
}
if strings.TrimSpace(req.Language) != "" {
current.Language = strings.TrimSpace(req.Language)
}
_, err = s.queries.UpsertUserSettings(ctx, sqlc.UpsertUserSettingsParams{
ID: pgID,
ChatModelID: pgtype.Text{String: current.ChatModelID, Valid: current.ChatModelID != ""},
MemoryModelID: pgtype.Text{String: current.MemoryModelID, Valid: current.MemoryModelID != ""},
EmbeddingModelID: pgtype.Text{String: current.EmbeddingModelID, Valid: current.EmbeddingModelID != ""},
MaxContextLoadTime: int32(current.MaxContextLoadTime),
Language: current.Language,
})
if err != nil {
return Settings{}, err
}
return current, nil
}
func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) {
pgID, err := db.ParseUUID(botID)
if err != nil {
@@ -198,23 +120,6 @@ func (s *Service) Delete(ctx context.Context, botID string) error {
return s.queries.DeleteSettingsByBotID(ctx, pgID)
}
func normalizeUserSetting(row sqlc.GetSettingsByUserIDRow) Settings {
settings := Settings{
ChatModelID: strings.TrimSpace(row.ChatModelID.String),
MemoryModelID: strings.TrimSpace(row.MemoryModelID.String),
EmbeddingModelID: strings.TrimSpace(row.EmbeddingModelID.String),
MaxContextLoadTime: int(row.MaxContextLoadTime),
Language: strings.TrimSpace(row.Language),
}
if settings.MaxContextLoadTime <= 0 {
settings.MaxContextLoadTime = DefaultMaxContextLoadTime
}
if settings.Language == "" {
settings.Language = DefaultLanguage
}
return settings
}
func normalizeBotSetting(maxContextLoadTime int32, language string, allowGuest bool) Settings {
settings := Settings{
MaxContextLoadTime: int(maxContextLoadTime),
@@ -277,4 +182,3 @@ func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.
}
return row.ID, nil
}
+3 -66
View File
@@ -62,13 +62,6 @@ type ScheduleListResponse = {
items: Schedule[]
}
type Settings = {
chat_model_id: string
memory_model_id: string
embedding_model_id: string
max_context_load_time: number
language: string
}
type Bot = {
id: string
@@ -280,18 +273,6 @@ configCmd.action(async () => {
const config = readConfig()
console.log(`host = "${config.host}"`)
console.log(`port = ${config.port}`)
const token = readToken()
if (!token?.access_token) return
try {
const settings = await apiRequest<Settings>('/settings', {}, token)
console.log(`chat_model_id = "${settings.chat_model_id || ''}"`)
console.log(`memory_model_id = "${settings.memory_model_id || ''}"`)
console.log(`embedding_model_id = "${settings.embedding_model_id || ''}"`)
console.log(`max_context_load_time = ${settings.max_context_load_time}`)
console.log(`language = "${settings.language}"`)
} catch (err: unknown) {
console.log(chalk.yellow(`Unable to load settings: ${getErrorMessage(err)}`))
}
})
configCmd
@@ -299,33 +280,12 @@ configCmd
.description('Update config')
.option('--host <host>')
.option('--port <port>')
.option('--chat_model_id <model_id>')
.option('--memory_model_id <model_id>')
.option('--embedding_model_id <model_id>')
.option('--max_context_load_time <minutes>')
.option('--language <language>')
.action(async (opts) => {
const current = readConfig()
let host = opts.host
let port = opts.port ? Number.parseInt(opts.port, 10) : undefined
let maxContextLoadTime: number | undefined
if (opts.max_context_load_time !== undefined) {
const parsed = Number.parseInt(opts.max_context_load_time, 10)
if (Number.isNaN(parsed) || parsed <= 0) {
console.log(chalk.red('max_context_load_time must be a positive integer.'))
process.exit(1)
}
maxContextLoadTime = parsed
}
let language = opts.language
const hasSettingsInput = opts.max_context_load_time !== undefined
|| opts.language !== undefined
|| opts.chat_model_id !== undefined
|| opts.memory_model_id !== undefined
|| opts.embedding_model_id !== undefined
const hasConfigInput = Boolean(host || port)
if (!hasConfigInput && !hasSettingsInput) {
if (!host && !port) {
const answers = await inquirer.prompt([
{ type: 'input', name: 'host', message: 'Host:', default: current.host },
{ type: 'input', name: 'port', message: 'Port:', default: current.port },
@@ -337,31 +297,8 @@ configCmd
if (host) current.host = host
if (port && !Number.isNaN(port)) current.port = port
if (host || (port && !Number.isNaN(port))) {
writeConfig(current)
console.log(chalk.green('Config updated'))
}
if (hasSettingsInput) {
if (language) {
language = String(language).trim()
}
const payload: Partial<Settings> = {}
if (opts.chat_model_id) payload.chat_model_id = String(opts.chat_model_id).trim()
if (opts.memory_model_id) payload.memory_model_id = String(opts.memory_model_id).trim()
if (opts.embedding_model_id) payload.embedding_model_id = String(opts.embedding_model_id).trim()
if (maxContextLoadTime !== undefined) payload.max_context_load_time = maxContextLoadTime
if (language) payload.language = language
const token = ensureAuth()
const spinner = ora('Updating settings...').start()
try {
await apiRequest('/settings', { method: 'PUT', body: JSON.stringify(payload) }, token)
spinner.succeed('Settings updated')
} catch (err: unknown) {
spinner.fail(getErrorMessage(err) || 'Failed to update settings')
process.exit(1)
}
}
writeConfig(current)
console.log(chalk.green('Config updated'))
})
const provider = program.command('provider').description('Provider management')
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
-41
View File
@@ -374,11 +374,6 @@ export type HandlersEmbeddingsUsage = {
video_tokens?: number;
};
export type HandlersEnableModelRequest = {
as?: string;
model_id?: string;
};
export type HandlersErrorResponse = {
message?: string;
};
@@ -2923,42 +2918,6 @@ export type GetModelsCountResponses = {
export type GetModelsCountResponse = GetModelsCountResponses[keyof GetModelsCountResponses];
export type PostModelsEnableData = {
/**
* Enable model payload
*/
body: HandlersEnableModelRequest;
path?: never;
query?: never;
url: '/models/enable';
};
export type PostModelsEnableErrors = {
/**
* Bad Request
*/
400: HandlersErrorResponse;
/**
* Not Found
*/
404: HandlersErrorResponse;
/**
* Internal Server Error
*/
500: HandlersErrorResponse;
};
export type PostModelsEnableError = PostModelsEnableErrors[keyof PostModelsEnableErrors];
export type PostModelsEnableResponses = {
/**
* OK
*/
200: SettingsSettings;
};
export type PostModelsEnableResponse = PostModelsEnableResponses[keyof PostModelsEnableResponses];
export type DeleteModelsModelByModelIdData = {
body?: never;
path: {
-57
View File
@@ -2640,52 +2640,6 @@ const docTemplate = `{
}
}
},
"/models/enable": {
"post": {
"description": "Update the current user's settings to use the selected model",
"tags": [
"models"
],
"summary": "Enable model for chat/memory/embedding",
"parameters": [
{
"description": "Enable model payload",
"name": "payload",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.EnableModelRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/settings.Settings"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/models/model/{modelId}": {
"get": {
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
@@ -4754,17 +4708,6 @@ const docTemplate = `{
}
}
},
"handlers.EnableModelRequest": {
"type": "object",
"properties": {
"as": {
"type": "string"
},
"model_id": {
"type": "string"
}
}
},
"handlers.ErrorResponse": {
"type": "object",
"properties": {
-57
View File
@@ -2631,52 +2631,6 @@
}
}
},
"/models/enable": {
"post": {
"description": "Update the current user's settings to use the selected model",
"tags": [
"models"
],
"summary": "Enable model for chat/memory/embedding",
"parameters": [
{
"description": "Enable model payload",
"name": "payload",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/handlers.EnableModelRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/settings.Settings"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"404": {
"description": "Not Found",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/handlers.ErrorResponse"
}
}
}
}
},
"/models/model/{modelId}": {
"get": {
"description": "Get a model configuration by its model_id field (e.g., gpt-4)",
@@ -4745,17 +4699,6 @@
}
}
},
"handlers.EnableModelRequest": {
"type": "object",
"properties": {
"as": {
"type": "string"
},
"model_id": {
"type": "string"
}
}
},
"handlers.ErrorResponse": {
"type": "object",
"properties": {
-37
View File
@@ -629,13 +629,6 @@ definitions:
video_tokens:
type: integer
type: object
handlers.EnableModelRequest:
properties:
as:
type: string
model_id:
type: string
type: object
handlers.ErrorResponse:
properties:
message:
@@ -3023,36 +3016,6 @@ paths:
summary: Get model count
tags:
- models
/models/enable:
post:
description: Update the current user's settings to use the selected model
parameters:
- description: Enable model payload
in: body
name: payload
required: true
schema:
$ref: '#/definitions/handlers.EnableModelRequest'
responses:
"200":
description: OK
schema:
$ref: '#/definitions/settings.Settings'
"400":
description: Bad Request
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"404":
description: Not Found
schema:
$ref: '#/definitions/handlers.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/handlers.ErrorResponse'
summary: Enable model for chat/memory/embedding
tags:
- models
/models/model/{modelId}:
delete:
description: Delete a model configuration by its model_id field (e.g., gpt-4)