merge(github/main): integrate fx dependency injection framework

Merge upstream fx refactor and adapt all services to use go.uber.org/fx
for dependency injection. Resolve conflicts in main.go, server.go,
and service constructors while preserving our domain model changes.

- Fix telegram adapter panic on shutdown (double close channel)
- Fix feishu adapter processing messages after stop
- Increase directory lookup timeout from 2s to 5s
This commit is contained in:
BBQ
2026-02-12 15:49:17 +08:00
14 changed files with 578 additions and 752 deletions
+45
View File
@@ -0,0 +1,45 @@
package boot
import (
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/memohai/memoh/internal/config"
)
type RuntimeConfig struct {
JwtSecret string
JwtExpiresIn time.Duration
ServerAddr string
ContainerdSocketPath string
}
func ProvideRuntimeConfig(cfg config.Config) (*RuntimeConfig, error) {
if strings.TrimSpace(cfg.Auth.JWTSecret) == "" {
return nil, errors.New("jwt secret is required")
}
jwtExpiresIn, err := time.ParseDuration(cfg.Auth.JWTExpiresIn)
if err != nil {
return nil, fmt.Errorf("invalid jwt expires in: %w", err)
}
ret := &RuntimeConfig{
JwtSecret: cfg.Auth.JWTSecret,
JwtExpiresIn: jwtExpiresIn,
ServerAddr: cfg.Server.Addr,
ContainerdSocketPath: cfg.Containerd.SocketPath,
}
if value := os.Getenv("HTTP_ADDR"); value != "" {
ret.ServerAddr = value
}
if value := os.Getenv("CONTAINERD_SOCKET"); value != "" {
ret.ContainerdSocketPath = value
}
return ret, nil
}
@@ -279,6 +279,9 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig,
feishuCfg.EncryptKey,
)
eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error {
if connCtx.Err() != nil {
return nil
}
msg := extractFeishuInbound(event)
text := msg.Message.PlainText()
rawMessageID := ""
@@ -165,10 +165,6 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig
for {
select {
case <-connCtx.Done():
if a.logger != nil {
a.logger.Info("stop", slog.String("config_id", cfg.ID))
}
bot.StopReceivingUpdates()
return
case update, ok := <-updates:
if !ok {
@@ -251,12 +247,19 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig
}
}()
stop := func(context.Context) error {
stop := func(_ context.Context) error {
if a.logger != nil {
a.logger.Info("stop", slog.String("config_id", cfg.ID))
}
cancel()
bot.StopReceivingUpdates()
cancel()
// Drain remaining updates so the library's polling goroutine can
// finish writing and exit. Without this, the in-flight long-poll
// HTTP request keeps the old getUpdates session alive, causing
// "Conflict: terminated by other getUpdates request" when a new
// connection starts with the same bot token.
for range updates {
}
return nil
}
return channel.NewConnection(cfg, stop), nil
-128
View File
@@ -1,128 +0,0 @@
package channel
import (
"context"
"fmt"
"log/slog"
"testing"
)
// mockAdapter 专门用于 Manager 路由测试
type mockAdapter struct {
sentMessages []OutboundMessage
}
func (m *mockAdapter) Type() ChannelType { return ChannelType("test") }
func (m *mockAdapter) Descriptor() Descriptor {
return Descriptor{Type: ChannelType("test"), DisplayName: "Test", Capabilities: ChannelCapabilities{Text: true}}
}
func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error {
m.sentMessages = append(m.sentMessages, msg)
return nil
}
type fakeInboundProcessor struct {
resp *OutboundMessage
err error
gotCfg ChannelConfig
gotMsg InboundMessage
}
func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender ReplySender) error {
f.gotCfg = cfg
f.gotMsg = msg
if f.err != nil {
return f.err
}
if f.resp == nil {
return nil
}
if sender == nil {
return fmt.Errorf("sender missing")
}
return sender.Send(ctx, *f.resp)
}
func TestManager_HandleInbound_CoreLogic(t *testing.T) {
logger := slog.Default()
t.Run("返回回复_发送成功", func(t *testing.T) {
processor := &fakeInboundProcessor{
resp: &OutboundMessage{
Target: "target-id",
Message: Message{
Text: "AI回复内容",
},
},
}
reg := NewRegistry()
m := NewManager(logger, reg, &fakeConfigStore{}, processor)
adapter := &mockAdapter{}
m.RegisterAdapter(adapter)
cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")}
msg := InboundMessage{
Channel: ChannelType("test"),
Message: Message{Text: "你好"},
ReplyTarget: "target-id",
Conversation: Conversation{
ID: "chat-1",
Type: "p2p",
},
}
err := m.handleInbound(context.Background(), cfg, msg)
if err != nil {
t.Fatalf("不应报错: %v", err)
}
// 验证: 是否正确调用了 Adapter 发送回复
if len(adapter.sentMessages) != 1 {
t.Fatalf("应该发送 1 条回复,实际发送: %d", len(adapter.sentMessages))
}
if adapter.sentMessages[0].Message.PlainText() != "AI回复内容" {
t.Errorf("回复内容错误: %s", adapter.sentMessages[0].Message.PlainText())
}
if adapter.sentMessages[0].Target != "target-id" {
t.Errorf("回复目标错误: %s", adapter.sentMessages[0].Target)
}
})
t.Run("无回复_不发送", func(t *testing.T) {
processor := &fakeInboundProcessor{resp: nil}
reg := NewRegistry()
m := NewManager(logger, reg, &fakeConfigStore{}, processor)
adapter := &mockAdapter{}
m.RegisterAdapter(adapter)
cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")}
msg := InboundMessage{
Channel: ChannelType("test"),
Message: Message{Text: "你好"},
ReplyTarget: "target-id",
}
err := m.handleInbound(context.Background(), cfg, msg)
if err != nil {
t.Fatalf("不应报错: %v", err)
}
if len(adapter.sentMessages) != 0 {
t.Errorf("不应发送回复,实际发送: %+v", adapter.sentMessages)
}
})
t.Run("处理失败_返回错误", func(t *testing.T) {
processor := &fakeInboundProcessor{err: context.Canceled}
reg := NewRegistry()
m := NewManager(logger, reg, &fakeConfigStore{}, processor)
cfg := ChannelConfig{ID: "bot-1"}
msg := InboundMessage{Message: Message{Text: " "}} // 空格消息
err := m.handleInbound(context.Background(), cfg, msg)
if err == nil {
t.Errorf("应返回处理错误")
}
})
}
+3 -1
View File
@@ -26,6 +26,7 @@ import (
"github.com/containerd/containerd/v2/pkg/oci"
"github.com/containerd/errdefs"
"github.com/containerd/platforms"
"github.com/memohai/memoh/internal/config"
"github.com/opencontainers/go-digest"
"github.com/opencontainers/image-spec/identity"
"github.com/opencontainers/runtime-spec/specs-go"
@@ -147,7 +148,8 @@ type DefaultService struct {
logger *slog.Logger
}
func NewDefaultService(log *slog.Logger, client *containerd.Client, namespace string) *DefaultService {
func NewDefaultService(log *slog.Logger, client *containerd.Client, cfg config.Config) *DefaultService {
namespace := cfg.Containerd.Namespace
if namespace == "" {
namespace = DefaultNamespace
}
+3 -7
View File
@@ -60,7 +60,7 @@ type Manager struct {
logger *slog.Logger
}
func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string) *Manager {
func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, conn *pgxpool.Pool) *Manager {
if namespace == "" {
namespace = config.DefaultNamespace
}
@@ -68,6 +68,8 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, nam
service: service,
cfg: cfg,
namespace: namespace,
db: conn,
queries: dbsqlc.New(conn),
logger: log.With(slog.String("component", "mcp")),
containerID: func(botID string) string {
return ContainerPrefix + botID
@@ -75,12 +77,6 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, nam
}
}
func (m *Manager) WithDB(db *pgxpool.Pool) *Manager {
m.db = db
m.queries = dbsqlc.New(db)
return m
}
func (m *Manager) Init(ctx context.Context) error {
image := DefaultImageRef
+1 -1
View File
@@ -494,7 +494,7 @@ func (r *IdentityResolver) resolveDisplayNameFromDirectory(ctx context.Context,
if ctx == nil {
ctx = context.Background()
}
lookupCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
entry, err := directoryAdapter.ResolveEntry(lookupCtx, cfg, subjectID, channel.DirectoryEntryUser)
if err != nil {
+3 -2
View File
@@ -14,6 +14,7 @@ import (
"github.com/robfig/cron/v3"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/boot"
"github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
@@ -29,7 +30,7 @@ type Service struct {
jobs map[string]cron.EntryID
}
func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jwtSecret string) *Service {
func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, runtimeConfig *boot.RuntimeConfig) *Service {
parser := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
c := cron.New(cron.WithParser(parser))
service := &Service{
@@ -37,7 +38,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, jw
cron: c,
parser: parser,
triggerer: triggerer,
jwtSecret: jwtSecret,
jwtSecret: runtimeConfig.JwtSecret,
logger: log.With(slog.String("service", "schedule")),
jobs: map[string]cron.EntryID{},
}
+16 -58
View File
@@ -1,6 +1,7 @@
package server
import (
"context"
"log/slog"
"strings"
@@ -8,7 +9,6 @@ import (
"github.com/labstack/echo/v4/middleware"
"github.com/memohai/memoh/internal/auth"
"github.com/memohai/memoh/internal/handlers"
)
type Server struct {
@@ -17,7 +17,13 @@ type Server struct {
logger *slog.Logger
}
func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *handlers.PingHandler, authHandler *handlers.AuthHandler, memoryHandler *handlers.MemoryHandler, embeddingsHandler *handlers.EmbeddingsHandler, conversationHandler *handlers.MessageHandler, swaggerHandler *handlers.SwaggerHandler, providersHandler *handlers.ProvidersHandler, modelsHandler *handlers.ModelsHandler, settingsHandler *handlers.SettingsHandler, preauthHandler *handlers.PreauthHandler, bindHandler *handlers.BindHandler, scheduleHandler *handlers.ScheduleHandler, subagentHandler *handlers.SubagentHandler, containerdHandler *handlers.ContainerdHandler, channelHandler *handlers.ChannelHandler, usersHandler *handlers.UsersHandler, mcpHandler *handlers.MCPHandler, cliHandler *handlers.LocalChannelHandler, webHandler *handlers.LocalChannelHandler) *Server {
type Handler interface {
Register(e *echo.Echo)
}
func NewServer(log *slog.Logger, addr string, jwtSecret string,
handlers ...Handler,
) *Server {
if addr == "" {
addr = ":8080"
}
@@ -51,62 +57,10 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
return false
}))
if pingHandler != nil {
pingHandler.Register(e)
}
if authHandler != nil {
authHandler.Register(e)
}
if memoryHandler != nil {
memoryHandler.Register(e)
}
if embeddingsHandler != nil {
embeddingsHandler.Register(e)
}
if conversationHandler != nil {
conversationHandler.Register(e)
}
if swaggerHandler != nil {
swaggerHandler.Register(e)
}
if settingsHandler != nil {
settingsHandler.Register(e)
}
if preauthHandler != nil {
preauthHandler.Register(e)
}
if bindHandler != nil {
bindHandler.Register(e)
}
if scheduleHandler != nil {
scheduleHandler.Register(e)
}
if subagentHandler != nil {
subagentHandler.Register(e)
}
if providersHandler != nil {
providersHandler.Register(e)
}
if modelsHandler != nil {
modelsHandler.Register(e)
}
if containerdHandler != nil {
containerdHandler.Register(e)
}
if channelHandler != nil {
channelHandler.Register(e)
}
if usersHandler != nil {
usersHandler.Register(e)
}
if mcpHandler != nil {
mcpHandler.Register(e)
}
if cliHandler != nil {
cliHandler.Register(e)
}
if webHandler != nil {
webHandler.Register(e)
for _, h := range handlers {
if h != nil {
h.Register(e)
}
}
return &Server{
@@ -119,3 +73,7 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, pingHandler *han
func (s *Server) Start() error {
return s.echo.Start(s.addr)
}
func (s *Server) Stop(ctx context.Context) error {
return s.echo.Shutdown(ctx)
}