mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
merge(github/main): integrate fx dependency injection framework
Merge upstream fx refactor and adapt all services to use go.uber.org/fx for dependency injection. Resolve conflicts in main.go, server.go, and service constructors while preserving our domain model changes. - Fix telegram adapter panic on shutdown (double close channel) - Fix feishu adapter processing messages after stop - Increase directory lookup timeout from 2s to 5s
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
package boot
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
)
|
||||
|
||||
type RuntimeConfig struct {
|
||||
JwtSecret string
|
||||
JwtExpiresIn time.Duration
|
||||
ServerAddr string
|
||||
ContainerdSocketPath string
|
||||
}
|
||||
|
||||
func ProvideRuntimeConfig(cfg config.Config) (*RuntimeConfig, error) {
|
||||
if strings.TrimSpace(cfg.Auth.JWTSecret) == "" {
|
||||
return nil, errors.New("jwt secret is required")
|
||||
}
|
||||
|
||||
jwtExpiresIn, err := time.ParseDuration(cfg.Auth.JWTExpiresIn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid jwt expires in: %w", err)
|
||||
}
|
||||
|
||||
ret := &RuntimeConfig{
|
||||
JwtSecret: cfg.Auth.JWTSecret,
|
||||
JwtExpiresIn: jwtExpiresIn,
|
||||
ServerAddr: cfg.Server.Addr,
|
||||
ContainerdSocketPath: cfg.Containerd.SocketPath,
|
||||
}
|
||||
|
||||
if value := os.Getenv("HTTP_ADDR"); value != "" {
|
||||
ret.ServerAddr = value
|
||||
}
|
||||
|
||||
if value := os.Getenv("CONTAINERD_SOCKET"); value != "" {
|
||||
ret.ContainerdSocketPath = value
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
@@ -279,6 +279,9 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig,
|
||||
feishuCfg.EncryptKey,
|
||||
)
|
||||
eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
if connCtx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
msg := extractFeishuInbound(event)
|
||||
text := msg.Message.PlainText()
|
||||
rawMessageID := ""
|
||||
|
||||
@@ -165,10 +165,6 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig
|
||||
for {
|
||||
select {
|
||||
case <-connCtx.Done():
|
||||
if a.logger != nil {
|
||||
a.logger.Info("stop", slog.String("config_id", cfg.ID))
|
||||
}
|
||||
bot.StopReceivingUpdates()
|
||||
return
|
||||
case update, ok := <-updates:
|
||||
if !ok {
|
||||
@@ -251,12 +247,19 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig
|
||||
}
|
||||
}()
|
||||
|
||||
stop := func(context.Context) error {
|
||||
stop := func(_ context.Context) error {
|
||||
if a.logger != nil {
|
||||
a.logger.Info("stop", slog.String("config_id", cfg.ID))
|
||||
}
|
||||
cancel()
|
||||
bot.StopReceivingUpdates()
|
||||
cancel()
|
||||
// Drain remaining updates so the library's polling goroutine can
|
||||
// finish writing and exit. Without this, the in-flight long-poll
|
||||
// HTTP request keeps the old getUpdates session alive, causing
|
||||
// "Conflict: terminated by other getUpdates request" when a new
|
||||
// connection starts with the same bot token.
|
||||
for range updates {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return channel.NewConnection(cfg, stop), nil
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockAdapter 专门用于 Manager 路由测试
|
||||
type mockAdapter struct {
|
||||
sentMessages []OutboundMessage
|
||||
}
|
||||
|
||||
func (m *mockAdapter) Type() ChannelType { return ChannelType("test") }
|
||||
func (m *mockAdapter) Descriptor() Descriptor {
|
||||
return Descriptor{Type: ChannelType("test"), DisplayName: "Test", Capabilities: ChannelCapabilities{Text: true}}
|
||||
}
|
||||
func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error {
|
||||
m.sentMessages = append(m.sentMessages, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeInboundProcessor struct {
|
||||
resp *OutboundMessage
|
||||
err error
|
||||
gotCfg ChannelConfig
|
||||
gotMsg InboundMessage
|
||||
}
|
||||
|
||||
func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error {
|
||||
f.gotCfg = cfg
|
||||
f.gotMsg = msg
|
||||
if f.err != nil {
|
||||
return f.err
|
||||
}
|
||||
if f.resp == nil {
|
||||
return nil
|
||||
}
|
||||
if sender == nil {
|
||||
return fmt.Errorf("sender missing")
|
||||
}
|
||||
return sender.Send(ctx, *f.resp)
|
||||
}
|
||||
|
||||
func TestManager_HandleInbound_CoreLogic(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
|
||||
t.Run("返回回复_发送成功", func(t *testing.T) {
|
||||
processor := &fakeInboundProcessor{
|
||||
resp: &OutboundMessage{
|
||||
Target: "target-id",
|
||||
Message: Message{
|
||||
Text: "AI回复内容",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
reg := NewRegistry()
|
||||
m := NewManager(logger, reg, &fakeConfigStore{}, processor)
|
||||
adapter := &mockAdapter{}
|
||||
m.RegisterAdapter(adapter)
|
||||
|
||||
cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")}
|
||||
msg := InboundMessage{
|
||||
Channel: ChannelType("test"),
|
||||
Message: Message{Text: "你好"},
|
||||
ReplyTarget: "target-id",
|
||||
Conversation: Conversation{
|
||||
ID: "chat-1",
|
||||
Type: "p2p",
|
||||
},
|
||||
}
|
||||
|
||||
err := m.handleInbound(context.Background(), cfg, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("不应报错: %v", err)
|
||||
}
|
||||
|
||||
// 验证: 是否正确调用了 Adapter 发送回复
|
||||
if len(adapter.sentMessages) != 1 {
|
||||
t.Fatalf("应该发送 1 条回复,实际发送: %d", len(adapter.sentMessages))
|
||||
}
|
||||
if adapter.sentMessages[0].Message.PlainText() != "AI回复内容" {
|
||||
t.Errorf("回复内容错误: %s", adapter.sentMessages[0].Message.PlainText())
|
||||
}
|
||||
if adapter.sentMessages[0].Target != "target-id" {
|
||||
t.Errorf("回复目标错误: %s", adapter.sentMessages[0].Target)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("无回复_不发送", func(t *testing.T) {
|
||||
processor := &fakeInboundProcessor{resp: nil}
|
||||
reg := NewRegistry()
|
||||
m := NewManager(logger, reg, &fakeConfigStore{}, processor)
|
||||
adapter := &mockAdapter{}
|
||||
m.RegisterAdapter(adapter)
|
||||
|
||||
cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")}
|
||||
msg := InboundMessage{
|
||||
Channel: ChannelType("test"),
|
||||
Message: Message{Text: "你好"},
|
||||
ReplyTarget: "target-id",
|
||||
}
|
||||
|
||||
err := m.handleInbound(context.Background(), cfg, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("不应报错: %v", err)
|
||||
}
|
||||
|
||||
if len(adapter.sentMessages) != 0 {
|
||||
t.Errorf("不应发送回复,实际发送: %+v", adapter.sentMessages)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("处理失败_返回错误", func(t *testing.T) {
|
||||
processor := &fakeInboundProcessor{err: context.Canceled}
|
||||
reg := NewRegistry()
|
||||
m := NewManager(logger, reg, &fakeConfigStore{}, processor)
|
||||
cfg := ChannelConfig{ID: "bot-1"}
|
||||
msg := InboundMessage{Message: Message{Text: " "}} // 空格消息
|
||||
|
||||
err := m.handleInbound(context.Background(), cfg, msg)
|
||||
if err == nil {
|
||||
t.Errorf("应返回处理错误")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/containerd/containerd/v2/pkg/oci"
|
||||
"github.com/containerd/errdefs"
|
||||
"github.com/containerd/platforms"
|
||||
"github.com/memohai/memoh/internal/config"
|
||||
"github.com/opencontainers/go-digest"
|
||||
"github.com/opencontainers/image-spec/identity"
|
||||
"github.com/opencontainers/runtime-spec/specs-go"
|
||||
@@ -147,7 +148,8 @@ type DefaultService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewDefaultService(log *slog.Logger, client *containerd.Client, namespace string) *DefaultService {
|
||||
func NewDefaultService(log *slog.Logger, client *containerd.Client, cfg config.Config) *DefaultService {
|
||||
namespace := cfg.Containerd.Namespace
|
||||
if namespace == "" {
|
||||
namespace = DefaultNamespace
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ type Manager struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string) *Manager {
|
||||
func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, conn *pgxpool.Pool) *Manager {
|
||||
if namespace == "" {
|
||||
namespace = config.DefaultNamespace
|
||||
}
|
||||
@@ -68,6 +68,8 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, nam
|
||||
service: service,
|
||||
cfg: cfg,
|
||||
namespace: namespace,
|
||||
db: conn,
|
||||
queries: dbsqlc.New(conn),
|
||||
logger: log.With(slog.String("component", "mcp")),
|
||||
containerID: func(botID string) string {
|
||||
return ContainerPrefix + botID
|
||||
@@ -75,12 +77,6 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, nam
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) WithDB(db *pgxpool.Pool) *Manager {
|
||||
m.db = db
|
||||
m.queries = dbsqlc.New(db)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Manager) Init(ctx context.Context) error {
|
||||
image := DefaultImageRef
|
||||
|
||||
|
||||
@@ -494,7 +494,7 @@ func (r *IdentityResolver) resolveDisplayNameFromDirectory(ctx context.Context,
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
entry, err := directoryAdapter.ResolveEntry(lookupCtx, cfg, subjectID, channel.DirectoryEntryUser)
|
||||
if err != nil {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/boot"
|
||||
"github.com/memohai/memoh/internal/db"
|
||||
"github.com/memohai/memoh/internal/db/sqlc"
|
||||
)
|
||||
@@ -29,7 +30,7 @@ type Service struct {
|
||||
jobs map[string]cron.EntryID
|
||||
}
|
||||
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jwtSecret string) *Service {
|
||||
func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, runtimeConfig *boot.RuntimeConfig) *Service {
|
||||
parser := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
|
||||
c := cron.New(cron.WithParser(parser))
|
||||
service := &Service{
|
||||
@@ -37,7 +38,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jw
|
||||
cron: c,
|
||||
parser: parser,
|
||||
triggerer: triggerer,
|
||||
jwtSecret: jwtSecret,
|
||||
jwtSecret: runtimeConfig.JwtSecret,
|
||||
logger: log.With(slog.String("service", "schedule")),
|
||||
jobs: map[string]cron.EntryID{},
|
||||
}
|
||||
|
||||
+16
-58
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
@@ -8,7 +9,6 @@ import (
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
|
||||
"github.com/memohai/memoh/internal/auth"
|
||||
"github.com/memohai/memoh/internal/handlers"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@@ -17,7 +17,13 @@ type Server struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, conversationHandler *handlers.MessageHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, preauthHandler *handlers.PreauthHandler, bindHandler *handlers.BindHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server {
|
||||
type Handler interface {
|
||||
Register(e *echo.Echo)
|
||||
}
|
||||
|
||||
func NewServer(log *slog.Logger, addr string, jwtSecret string,
|
||||
handlers ...Handler,
|
||||
) *Server {
|
||||
if addr == "" {
|
||||
addr = ":8080"
|
||||
}
|
||||
@@ -51,62 +57,10 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
|
||||
return false
|
||||
}))
|
||||
|
||||
if pingHandler != nil {
|
||||
pingHandler.Register(e)
|
||||
}
|
||||
if authHandler != nil {
|
||||
authHandler.Register(e)
|
||||
}
|
||||
if memoryHandler != nil {
|
||||
memoryHandler.Register(e)
|
||||
}
|
||||
if embeddingsHandler != nil {
|
||||
embeddingsHandler.Register(e)
|
||||
}
|
||||
if conversationHandler != nil {
|
||||
conversationHandler.Register(e)
|
||||
}
|
||||
if swaggerHandler != nil {
|
||||
swaggerHandler.Register(e)
|
||||
}
|
||||
if settingsHandler != nil {
|
||||
settingsHandler.Register(e)
|
||||
}
|
||||
if preauthHandler != nil {
|
||||
preauthHandler.Register(e)
|
||||
}
|
||||
if bindHandler != nil {
|
||||
bindHandler.Register(e)
|
||||
}
|
||||
if scheduleHandler != nil {
|
||||
scheduleHandler.Register(e)
|
||||
}
|
||||
if subagentHandler != nil {
|
||||
subagentHandler.Register(e)
|
||||
}
|
||||
if providersHandler != nil {
|
||||
providersHandler.Register(e)
|
||||
}
|
||||
if modelsHandler != nil {
|
||||
modelsHandler.Register(e)
|
||||
}
|
||||
if containerdHandler != nil {
|
||||
containerdHandler.Register(e)
|
||||
}
|
||||
if channelHandler != nil {
|
||||
channelHandler.Register(e)
|
||||
}
|
||||
if usersHandler != nil {
|
||||
usersHandler.Register(e)
|
||||
}
|
||||
if mcpHandler != nil {
|
||||
mcpHandler.Register(e)
|
||||
}
|
||||
if cliHandler != nil {
|
||||
cliHandler.Register(e)
|
||||
}
|
||||
if webHandler != nil {
|
||||
webHandler.Register(e)
|
||||
for _, h := range handlers {
|
||||
if h != nil {
|
||||
h.Register(e)
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
@@ -119,3 +73,7 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
|
||||
func (s *Server) Start() error {
|
||||
return s.echo.Start(s.addr)
|
||||
}
|
||||
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
return s.echo.Shutdown(ctx)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user