diff --git a/.github/workflows/go-ci.yml b/.github/workflows/go-ci.yml new file mode 100644 index 00000000..fd990caf --- /dev/null +++ b/.github/workflows/go-ci.yml @@ -0,0 +1,69 @@ +name: Go CI + +on: + push: + branches: ["main"] + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".golangci.yml" + - "db/queries/**/*.sql" + - "db/migrations/**/*.sql" + - ".github/workflows/go-ci.yml" + pull_request: + branches: ["main"] + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".golangci.yml" + - "db/queries/**/*.sql" + - "db/migrations/**/*.sql" + - ".github/workflows/go-ci.yml" + workflow_dispatch: + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: v2.10.1 + only-new-issues: true + + test: + name: Test + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: Run tests + run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... + - name: Upload coverage + uses: actions/upload-artifact@v4 + with: + name: go-coverage + path: coverage.txt + if-no-files-found: error diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..096527da --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,107 @@ +version: "2" + +run: + timeout: 15m + go: "1.25" + tests: true + +output: + show-stats: false + +issues: + max-issues-per-linter: 0 + max-same-issues: 0 + +linters: + default: none + enable: + - bodyclose + - contextcheck + - errcheck + - errorlint + - exptostd + - fatcontext + - gocritic + - gosec + - godot + - govet + - ineffassign + - misspell + - noctx + - nilnesserr + - perfsprint + - predeclared + - revive + - sqlclosecheck + - sloglint + - staticcheck + - testifylint + - unconvert + - unused + - usestdlibvars + - whitespace + exclusions: + paths: + - internal/db/sqlc + - ^.*\.(pb|l|y)\.go$ + settings: + govet: + enable-all: true + disable: + - shadow + - fieldalignment + gocyclo: + min-complexity: 10 + funlen: + lines: 60 + statements: 30 + perfsprint: + int-conversion: true + err-error: true + errorf: true + sprintf1: true + strconcat: false + revive: + rules: + - name: blank-imports + - name: comment-spacings + - name: context-as-argument + arguments: + - allowTypesBefore: "*testing.T,testing.TB" + - name: dot-imports + - name: error-naming + - name: error-return + - name: error-strings + - name: increment-decrement + - name: var-declaration + - name: unreachable-code + - name: unused-parameter + - name: unused-receiver + sloglint: + attr-only: true + no-global: default + static-msg: true + key-naming-case: snake + forbidden-keys: [time, level, msg, source] + testifylint: + enable-all: true + disable: + - float-compare + - go-require + +formatters: + enable: + - gci + - gofumpt + - goimports + settings: + gci: + sections: + - standard + - default + - prefix(github.com/memohai/memoh) + gofumpt: + extra-rules: false + goimports: + local-prefixes: + - github.com/memohai/memoh diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 3688f5f5..fe4e86f0 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -37,21 +37,25 @@ import ( "github.com/memohai/memoh/internal/conversation/flow" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" + emailpkg "github.com/memohai/memoh/internal/email" + emailgeneric "github.com/memohai/memoh/internal/email/adapters/generic" + emailmailgun "github.com/memohai/memoh/internal/email/adapters/mailgun" "github.com/memohai/memoh/internal/handlers" "github.com/memohai/memoh/internal/healthcheck" channelchecker "github.com/memohai/memoh/internal/healthcheck/checkers/channel" mcpchecker "github.com/memohai/memoh/internal/healthcheck/checkers/mcp" modelchecker "github.com/memohai/memoh/internal/healthcheck/checkers/model" + "github.com/memohai/memoh/internal/heartbeat" "github.com/memohai/memoh/internal/inbox" "github.com/memohai/memoh/internal/logger" "github.com/memohai/memoh/internal/mcp" mcpcontacts "github.com/memohai/memoh/internal/mcp/providers/contacts" mcpcontainer "github.com/memohai/memoh/internal/mcp/providers/container" + mcpemail "github.com/memohai/memoh/internal/mcp/providers/email" mcpinbox "github.com/memohai/memoh/internal/mcp/providers/inbox" mcpmemory "github.com/memohai/memoh/internal/mcp/providers/memory" mcpmessage "github.com/memohai/memoh/internal/mcp/providers/message" mcpschedule "github.com/memohai/memoh/internal/mcp/providers/schedule" - mcpemail "github.com/memohai/memoh/internal/mcp/providers/email" mcpweb "github.com/memohai/memoh/internal/mcp/providers/web" mcpfederation "github.com/memohai/memoh/internal/mcp/sources/federation" "github.com/memohai/memoh/internal/media" @@ -62,11 +66,7 @@ import ( "github.com/memohai/memoh/internal/policy" "github.com/memohai/memoh/internal/preauth" "github.com/memohai/memoh/internal/providers" - "github.com/memohai/memoh/internal/heartbeat" "github.com/memohai/memoh/internal/schedule" - emailpkg "github.com/memohai/memoh/internal/email" - emailgeneric "github.com/memohai/memoh/internal/email/adapters/generic" - emailmailgun "github.com/memohai/memoh/internal/email/adapters/mailgun" "github.com/memohai/memoh/internal/searchproviders" "github.com/memohai/memoh/internal/server" "github.com/memohai/memoh/internal/settings" @@ -281,7 +281,7 @@ func provideContainerService(lc fx.Lifecycle, log *slog.Logger, cfg config.Confi return nil, err } lc.Append(fx.Hook{ - OnStop: func(ctx context.Context) error { + OnStop: func(_ context.Context) error { cleanup() return nil }, @@ -295,7 +295,7 @@ func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) { return nil, fmt.Errorf("db connect: %w", err) } lc.Append(fx.Hook{ - OnStop: func(ctx context.Context) error { + OnStop: func(_ context.Context) error { conn.Close() return nil }, @@ -327,7 +327,7 @@ func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, lo func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, manager *mcp.Manager) *memprovider.Registry { registry := memprovider.NewRegistry(log) builtinRuntime := handlers.NewBuiltinMemoryRuntime(manager) - registry.RegisterFactory(memprovider.BuiltinType, func(id string, config map[string]any) (memprovider.Provider, error) { + registry.RegisterFactory(memprovider.BuiltinType, func(_ string, _ map[string]any) (memprovider.Provider, error) { return memprovider.NewBuiltinProvider(log, builtinRuntime, chatService, accountService), nil }) registry.Register("__builtin_default__", memprovider.NewBuiltinProvider(log, builtinRuntime, chatService, accountService)) @@ -436,6 +436,7 @@ func provideContainerdHandler(log *slog.Logger, service ctr.Service, manager *mc func provideFederationGateway(log *slog.Logger, containerdHandler *handlers.ContainerdHandler) *handlers.MCPFederationGateway { return handlers.NewMCPFederationGateway(log, containerdHandler) } + func provideOAuthService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *mcp.OAuthService { addr := strings.TrimSpace(cfg.Server.Addr) if addr == "" { @@ -448,7 +449,8 @@ func provideOAuthService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.C callbackURL := "http://" + host + "/api/oauth/mcp/callback" return mcp.NewOAuthService(log, queries, callbackURL) } -func provideToolGatewayService(log *slog.Logger, cfg config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, chatService *conversation.Service, accountService *accounts.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *mcp.Manager, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService, mediaService *media.Service, inboxService *inbox.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, oauthService *mcp.OAuthService) *mcp.ToolGatewayService { + +func provideToolGatewayService(log *slog.Logger, _ config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, _ *conversation.Service, _ *accounts.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *mcp.Manager, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService, mediaService *media.Service, inboxService *inbox.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, oauthService *mcp.OAuthService) *mcp.ToolGatewayService { fedGateway.SetOAuthService(oauthService) var assetResolver mcpmessage.AssetResolver if mediaService != nil { @@ -477,7 +479,7 @@ func provideToolGatewayService(log *slog.Logger, cfg config.Config, channelManag // handler providers (interface adaptation / config extraction) // --------------------------------------------------------------------------- -func provideMemoryHandler(log *slog.Logger, botService *bots.Service, accountService *accounts.Service, cfg config.Config, manager *mcp.Manager, memoryRegistry *memprovider.Registry, settingsService *settings.Service, containerdHandler *handlers.ContainerdHandler) *handlers.MemoryHandler { +func provideMemoryHandler(log *slog.Logger, botService *bots.Service, accountService *accounts.Service, _ config.Config, manager *mcp.Manager, memoryRegistry *memprovider.Registry, settingsService *settings.Service, _ *handlers.ContainerdHandler) *handlers.MemoryHandler { h := handlers.NewMemoryHandler(log, botService, accountService) h.SetMemoryRegistry(memoryRegistry) h.SetSettingsService(settingsService) @@ -526,6 +528,7 @@ func provideEmailRegistry(log *slog.Logger) *emailpkg.Registry { func provideEmailChatGateway(resolver *flow.Resolver, queries *dbsqlc.Queries, cfg config.Config, log *slog.Logger) emailpkg.ChatTriggerer { return flow.NewEmailChatGateway(resolver, queries, cfg.Auth.JWTSecret, log) } + func provideEmailTrigger(log *slog.Logger, service *emailpkg.Service, botInbox *inbox.Service, chatTriggerer emailpkg.ChatTriggerer) *emailpkg.Trigger { return emailpkg.NewTrigger(log, service, botInbox, chatTriggerer) } @@ -541,9 +544,9 @@ func startEmailManager(lc fx.Lifecycle, emailManager *emailpkg.Manager) { }() return nil }, - OnStop: func(_ context.Context) error { + OnStop: func(stopCtx context.Context) error { cancel() - emailManager.Stop() + emailManager.Stop(stopCtx) return nil }, }) @@ -677,7 +680,7 @@ func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *server.Server, shutd func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error { if queries == nil { - return fmt.Errorf("db queries not configured") + return errors.New("db queries not configured") } count, err := queries.CountAccounts(ctx) if err != nil { @@ -691,7 +694,7 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer password := strings.TrimSpace(cfg.Admin.Password) email := strings.TrimSpace(cfg.Admin.Email) if username == "" || password == "" { - return fmt.Errorf("admin username/password required in config.toml") + return errors.New("admin username/password required in config.toml") } if password == "change-your-password-here" { log.Warn("admin password uses default placeholder; please update config.toml") @@ -780,7 +783,7 @@ func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string func (c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) { if c.modelsService == nil || c.queries == nil { - return nil, fmt.Errorf("models service not configured") + return nil, errors.New("models service not configured") } botID := "" memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, botID) @@ -795,7 +798,7 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) { } _ = memoryProvider _ = memoryModel - return nil, fmt.Errorf("memory llm runtime is not available") + return nil, errors.New("memory llm runtime is not available") } // skillLoaderAdapter bridges handlers.ContainerdHandler to flow.SkillLoader. @@ -827,7 +830,7 @@ type mediaAssetResolverAdapter struct { func (a *mediaAssetResolverAdapter) GetByStorageKey(ctx context.Context, botID, storageKey string) (mcpmessage.AssetMeta, error) { if a == nil || a.media == nil { - return mcpmessage.AssetMeta{}, fmt.Errorf("media service not configured") + return mcpmessage.AssetMeta{}, errors.New("media service not configured") } asset, err := a.media.GetByStorageKey(ctx, botID, storageKey) if err != nil { @@ -843,7 +846,7 @@ func (a *mediaAssetResolverAdapter) GetByStorageKey(ctx context.Context, botID, func (a *mediaAssetResolverAdapter) IngestContainerFile(ctx context.Context, botID, containerPath string) (mcpmessage.AssetMeta, error) { if a == nil || a.media == nil { - return mcpmessage.AssetMeta{}, fmt.Errorf("media service not configured") + return mcpmessage.AssetMeta{}, errors.New("media service not configured") } asset, err := a.media.IngestContainerFile(ctx, botID, containerPath) if err != nil { @@ -864,7 +867,7 @@ type gatewayAssetLoaderAdapter struct { func (a *gatewayAssetLoaderAdapter) OpenForGateway(ctx context.Context, botID, contentHash string) (io.ReadCloser, string, error) { if a == nil || a.media == nil { - return nil, "", fmt.Errorf("media service not configured") + return nil, "", errors.New("media service not configured") } reader, asset, err := a.media.Open(ctx, botID, contentHash) if err != nil { diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index a2f28684..d494211a 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -10,10 +10,11 @@ import ( "path/filepath" "syscall" - "github.com/memohai/memoh/internal/logger" - pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" "google.golang.org/grpc" "google.golang.org/grpc/reflection" + + "github.com/memohai/memoh/internal/logger" + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" ) const ( @@ -23,7 +24,7 @@ const ( // initDataDir ensures /data exists and seeds template files on first boot. func initDataDir() { - if err := os.MkdirAll(defaultWorkDir, 0o755); err != nil { + if err := os.MkdirAll(defaultWorkDir, 0o750); err != nil { logger.Warn("failed to create data dir", slog.Any("error", err)) return } @@ -64,10 +65,10 @@ func main() { addr = defaultListenAddr } - lis, err := net.Listen("tcp", addr) + lis, err := (&net.ListenConfig{}).Listen(ctx, "tcp", addr) if err != nil { logger.Error("failed to listen", slog.String("addr", addr), slog.Any("error", err)) - os.Exit(1) + return } srv := grpc.NewServer() @@ -76,13 +77,13 @@ func main() { go func() { <-ctx.Done() - logger.Info("shutting down gRPC server") + logger.FromContext(ctx).Info("shutting down gRPC server") srv.GracefulStop() }() logger.Info("mcp gRPC server listening", slog.String("addr", addr)) if err := srv.Serve(lis); err != nil { logger.Error("gRPC server failed", slog.Any("error", err)) - os.Exit(1) + return } } diff --git a/cmd/mcp/server.go b/cmd/mcp/server.go index e368ad86..304f08b2 100644 --- a/cmd/mcp/server.go +++ b/cmd/mcp/server.go @@ -4,9 +4,11 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "io/fs" + "math" "os" "os/exec" "path/filepath" @@ -14,9 +16,10 @@ import ( "time" "unicode/utf8" - pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" ) const ( @@ -33,18 +36,18 @@ type containerServer struct { pb.UnimplementedContainerServiceServer } -func (s *containerServer) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) { +func (*containerServer) ReadFile(_ context.Context, req *pb.ReadFileRequest) (*pb.ReadFileResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) - f, err := os.Open(path) + f, err := os.Open(path) //nolint:gosec // G304: MCP container filesystem server; paths are resolved within the container's /data, SSRF is by design if err != nil { return nil, status.Errorf(codes.NotFound, "open: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() probe := make([]byte, binaryProbeBytes) n, _ := f.Read(probe) @@ -108,23 +111,23 @@ func (s *containerServer) ReadFile(_ context.Context, req *pb.ReadFileRequest) ( }, nil } -func (s *containerServer) WriteFile(_ context.Context, req *pb.WriteFileRequest) (*pb.WriteFileResponse, error) { +func (*containerServer) WriteFile(_ context.Context, req *pb.WriteFileRequest) (*pb.WriteFileResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { return nil, status.Errorf(codes.Internal, "mkdir: %v", err) } - if err := os.WriteFile(path, req.GetContent(), 0o644); err != nil { + if err := os.WriteFile(path, req.GetContent(), 0o600); err != nil { return nil, status.Errorf(codes.Internal, "write: %v", err) } return &pb.WriteFileResponse{}, nil } -func (s *containerServer) ListDir(_ context.Context, req *pb.ListDirRequest) (*pb.ListDirResponse, error) { +func (*containerServer) ListDir(_ context.Context, req *pb.ListDirRequest) (*pb.ListDirResponse, error) { dir := req.GetPath() if dir == "" { dir = "." @@ -167,7 +170,7 @@ func (s *containerServer) ListDir(_ context.Context, req *pb.ListDirRequest) (*p return &pb.ListDirResponse{Entries: entries}, nil } -func (s *containerServer) Exec(stream pb.ContainerService_ExecServer) error { +func (*containerServer) Exec(stream pb.ContainerService_ExecServer) error { // Receive first message to get command details firstMsg, err := stream.Recv() if err != nil { @@ -192,7 +195,7 @@ func (s *containerServer) Exec(stream pb.ContainerService_ExecServer) error { ctx, cancel := context.WithTimeout(stream.Context(), time.Duration(timeout)*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) + cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) //nolint:gosec // G204: MCP exec tool intentionally executes agent-issued shell commands inside the container cmd.Dir = workDir if len(firstMsg.GetEnv()) > 0 { cmd.Env = append(os.Environ(), firstMsg.GetEnv()...) @@ -242,8 +245,10 @@ func (s *containerServer) Exec(stream pb.ContainerService_ExecServer) error { exitCode := int32(0) if err := cmd.Wait(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = int32(exitErr.ExitCode()) + exitErr := &exec.ExitError{} + if errors.As(err, &exitErr) { + ec := exitErr.ExitCode() + exitCode = int32(max(math.MinInt32, min(math.MaxInt32, ec))) //nolint:gosec // G115: value is clamped to int32 range above; Unix exit codes are 0-255 } else { exitCode = -1 } @@ -255,18 +260,18 @@ func (s *containerServer) Exec(stream pb.ContainerService_ExecServer) error { }) } -func (s *containerServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { +func (*containerServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerService_ReadRawServer) error { path := req.GetPath() if path == "" { return status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) - f, err := os.Open(path) + f, err := os.Open(path) //nolint:gosec // G304: MCP container filesystem server; path is resolved within the container if err != nil { return status.Errorf(codes.NotFound, "open: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() buf := make([]byte, rawChunkSize) for { @@ -286,13 +291,13 @@ func (s *containerServer) ReadRaw(req *pb.ReadRawRequest, stream pb.ContainerSer return nil } -func (s *containerServer) WriteRaw(stream pb.ContainerService_WriteRawServer) error { +func (*containerServer) WriteRaw(stream pb.ContainerService_WriteRawServer) error { var f *os.File var written int64 for { chunk, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { @@ -305,14 +310,14 @@ func (s *containerServer) WriteRaw(stream pb.ContainerService_WriteRawServer) er return status.Error(codes.InvalidArgument, "first chunk must include path") } path = resolvePath(path) - if mkErr := os.MkdirAll(filepath.Dir(path), 0o755); mkErr != nil { + if mkErr := os.MkdirAll(filepath.Dir(path), 0o750); mkErr != nil { return status.Errorf(codes.Internal, "mkdir: %v", mkErr) } - f, err = os.Create(path) + f, err = os.Create(path) //nolint:gosec // G304: MCP container filesystem server; path is resolved within the container if err != nil { return status.Errorf(codes.Internal, "create: %v", err) } - defer f.Close() + defer func() { _ = f.Close() }() } if len(chunk.GetData()) > 0 { @@ -327,7 +332,7 @@ func (s *containerServer) WriteRaw(stream pb.ContainerService_WriteRawServer) er return stream.SendAndClose(&pb.WriteRawResponse{BytesWritten: written}) } -func (s *containerServer) DeleteFile(_ context.Context, req *pb.DeleteFileRequest) (*pb.DeleteFileResponse, error) { +func (*containerServer) DeleteFile(_ context.Context, req *pb.DeleteFileRequest) (*pb.DeleteFileResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") @@ -346,7 +351,7 @@ func (s *containerServer) DeleteFile(_ context.Context, req *pb.DeleteFileReques return &pb.DeleteFileResponse{}, nil } -func (s *containerServer) Stat(_ context.Context, req *pb.StatRequest) (*pb.StatResponse, error) { +func (*containerServer) Stat(_ context.Context, req *pb.StatRequest) (*pb.StatResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") @@ -371,20 +376,20 @@ func (s *containerServer) Stat(_ context.Context, req *pb.StatRequest) (*pb.Stat }, nil } -func (s *containerServer) Mkdir(_ context.Context, req *pb.MkdirRequest) (*pb.MkdirResponse, error) { +func (*containerServer) Mkdir(_ context.Context, req *pb.MkdirRequest) (*pb.MkdirResponse, error) { path := req.GetPath() if path == "" { return nil, status.Error(codes.InvalidArgument, "path is required") } path = resolvePath(path) - if err := os.MkdirAll(path, 0o755); err != nil { + if err := os.MkdirAll(path, 0o750); err != nil { return nil, status.Errorf(codes.Internal, "mkdir: %v", err) } return &pb.MkdirResponse{}, nil } -func (s *containerServer) Rename(_ context.Context, req *pb.RenameRequest) (*pb.RenameResponse, error) { +func (*containerServer) Rename(_ context.Context, req *pb.RenameRequest) (*pb.RenameResponse, error) { oldPath := req.GetOldPath() newPath := req.GetNewPath() if oldPath == "" || newPath == "" { @@ -393,7 +398,7 @@ func (s *containerServer) Rename(_ context.Context, req *pb.RenameRequest) (*pb. oldPath = resolvePath(oldPath) newPath = resolvePath(newPath) - if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(newPath), 0o750); err != nil { return nil, status.Errorf(codes.Internal, "mkdir parent: %v", err) } if err := os.Rename(oldPath, newPath); err != nil { @@ -418,7 +423,7 @@ func streamPipe(stream pb.ContainerService_ExecServer, r io.Reader, st pb.ExecOu } } -func buildFileEntry(name, fullPath string, d fs.DirEntry) (*pb.FileEntry, error) { +func buildFileEntry(name, _ string, d fs.DirEntry) (*pb.FileEntry, error) { info, err := d.Info() if err != nil { return nil, err @@ -439,10 +444,10 @@ func resolvePath(path string) string { return filepath.Join(defaultWorkDir, path) } -func truncateRunes(s string, max int) string { +func truncateRunes(s string, maxRunes int) string { pos := 0 count := 0 - for pos < len(s) && count < max { + for pos < len(s) && count < maxRunes { _, size := utf8.DecodeRuneInString(s[pos:]) pos += size count++ diff --git a/cmd/memoh/main.go b/cmd/memoh/main.go index f4df8a0c..66b65e3a 100644 --- a/cmd/memoh/main.go +++ b/cmd/memoh/main.go @@ -10,7 +10,7 @@ func main() { rootCmd := &cobra.Command{ Use: "memoh", Short: "Memoh unified binary", - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { runServe() return nil }, @@ -19,7 +19,7 @@ func main() { rootCmd.AddCommand(&cobra.Command{ Use: "serve", Short: "Start the server", - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { runServe() return nil }, @@ -29,7 +29,7 @@ func main() { Use: "migrate ", Short: "Run database migrations", Args: cobra.MinimumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, args []string) error { return runMigrate(args) }, }) @@ -37,7 +37,7 @@ func main() { rootCmd.AddCommand(&cobra.Command{ Use: "version", Short: "Print version information", - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { return runVersion() }, }) diff --git a/cmd/memoh/serve.go b/cmd/memoh/serve.go index 290b7889..f7f966c4 100644 --- a/cmd/memoh/serve.go +++ b/cmd/memoh/serve.go @@ -197,7 +197,7 @@ func provideContainerService(lc fx.Lifecycle, log *slog.Logger, cfg config.Confi if err != nil { return nil, err } - lc.Append(fx.Hook{OnStop: func(ctx context.Context) error { cleanup(); return nil }}) + lc.Append(fx.Hook{OnStop: func(_ context.Context) error { cleanup(); return nil }}) return svc, nil } @@ -206,7 +206,7 @@ func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) { if err != nil { return nil, fmt.Errorf("db connect: %w", err) } - lc.Append(fx.Hook{OnStop: func(ctx context.Context) error { conn.Close(); return nil }}) + lc.Append(fx.Hook{OnStop: func(_ context.Context) error { conn.Close(); return nil }}) return conn, nil } @@ -214,21 +214,25 @@ func provideDBQueries(conn *pgxpool.Pool) *dbsqlc.Queries { return dbsqlc.New(co func provideMCPManager(log *slog.Logger, service ctr.Service, cfg config.Config, conn *pgxpool.Pool) *mcp.Manager { return mcp.NewManager(log, service, cfg.MCP, cfg.Containerd.Namespace, conn) } + func provideAgentRuntimeManager(log *slog.Logger, cfg config.Config) *agentruntime.Manager { return agentruntime.NewManager(log, cfg) } + func provideMemoryLLM(modelsService *models.Service, queries *dbsqlc.Queries, log *slog.Logger) memprovider.LLM { return &lazyLLMClient{modelsService: modelsService, queries: queries, timeout: 30 * time.Second, logger: log} } + func provideMemoryProviderRegistry(log *slog.Logger, chatService *conversation.Service, accountService *accounts.Service, manager *mcp.Manager) *memprovider.Registry { registry := memprovider.NewRegistry(log) builtinRuntime := handlers.NewBuiltinMemoryRuntime(manager) - registry.RegisterFactory(memprovider.BuiltinType, func(id string, config map[string]any) (memprovider.Provider, error) { + registry.RegisterFactory(memprovider.BuiltinType, func(_ string, _ map[string]any) (memprovider.Provider, error) { return memprovider.NewBuiltinProvider(log, builtinRuntime, chatService, accountService), nil }) registry.Register("__builtin_default__", memprovider.NewBuiltinProvider(log, builtinRuntime, chatService, accountService)) return registry } + func startMemoryProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, mpService *memprovider.Service, registry *memprovider.Registry) { lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { @@ -246,15 +250,19 @@ func startMemoryProviderBootstrap(lc fx.Lifecycle, log *slog.Logger, mpService * }, }) } + func provideRouteService(log *slog.Logger, queries *dbsqlc.Queries, chatService *conversation.Service) *route.DBService { return route.NewService(log, queries, chatService) } + func provideMessageService(log *slog.Logger, queries *dbsqlc.Queries, hub *event.Hub) *message.DBService { return message.NewService(log, queries, hub) } + func provideScheduleTriggerer(resolver *flow.Resolver) schedule.Triggerer { return flow.NewScheduleGateway(resolver) } + func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *models.Service, queries *dbsqlc.Queries, chatService *conversation.Service, msgService *message.DBService, settingsService *settings.Service, mediaService *media.Service, containerdHandler *handlers.ContainerdHandler, inboxService *inbox.Service, memoryRegistry *memprovider.Registry) *flow.Resolver { resolver := flow.NewResolver(log, modelsService, queries, chatService, msgService, settingsService, cfg.AgentGateway.BaseURL(), 120*time.Second) resolver.SetMemoryRegistry(memoryRegistry) @@ -263,6 +271,7 @@ func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *mod resolver.SetInboxService(inboxService) return resolver } + func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService *media.Service) *channel.Registry { registry := channel.NewRegistry() tgAdapter := telegram.NewTelegramAdapter(log) @@ -277,6 +286,7 @@ func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub, mediaService registry.MustRegister(local.NewWebAdapter(hub)) return registry } + func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *local.RouteHub, routeService *route.DBService, msgService *message.DBService, resolver *flow.Resolver, identityService *identities.Service, botService *bots.Service, policyService *policy.Service, preauthService *preauth.Service, bindService *bind.Service, mediaService *media.Service, inboxService *inbox.Service, rc *boot.RuntimeConfig) *inbound.ChannelInboundProcessor { processor := inbound.NewChannelInboundProcessor(log, registry, routeService, msgService, resolver, identityService, botService, policyService, preauthService, bindService, rc.JwtSecret, 5*time.Minute) processor.SetMediaService(mediaService) @@ -284,6 +294,7 @@ func provideChannelRouter(log *slog.Logger, registry *channel.Registry, hub *loc processor.SetInboxService(inboxService) return processor } + func provideChannelManager(log *slog.Logger, registry *channel.Registry, channelStore *channel.Store, channelRouter *inbound.ChannelInboundProcessor) *channel.Manager { mgr := channel.NewManager(log, registry, channelStore, channelRouter) if mw := channelRouter.IdentityMiddleware(); mw != nil { @@ -291,15 +302,19 @@ func provideChannelManager(log *slog.Logger, registry *channel.Registry, channel } return mgr } + func provideChannelLifecycleService(channelStore *channel.Store, channelManager *channel.Manager) *channel.Lifecycle { return channel.NewLifecycle(channelStore, channelManager) } + func provideContainerdHandler(log *slog.Logger, service ctr.Service, manager *mcp.Manager, cfg config.Config, rc *boot.RuntimeConfig, botService *bots.Service, accountService *accounts.Service, policyService *policy.Service, queries *dbsqlc.Queries) *handlers.ContainerdHandler { return handlers.NewContainerdHandler(log, service, manager, cfg.MCP, cfg.Containerd.Namespace, rc.ContainerBackend, botService, accountService, policyService, queries) } + func provideFederationGateway(log *slog.Logger, containerdHandler *handlers.ContainerdHandler) *handlers.MCPFederationGateway { return handlers.NewMCPFederationGateway(log, containerdHandler) } + func provideOAuthService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) *mcp.OAuthService { addr := strings.TrimSpace(cfg.Server.Addr) if addr == "" { @@ -312,7 +327,8 @@ func provideOAuthService(log *slog.Logger, queries *dbsqlc.Queries, cfg config.C callbackURL := "http://" + host + "/oauth/mcp/callback" return mcp.NewOAuthService(log, queries, callbackURL) } -func provideToolGatewayService(log *slog.Logger, cfg config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, chatService *conversation.Service, accountService *accounts.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *mcp.Manager, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService, mediaService *media.Service, inboxService *inbox.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, oauthService *mcp.OAuthService) *mcp.ToolGatewayService { + +func provideToolGatewayService(log *slog.Logger, _ config.Config, channelManager *channel.Manager, registry *channel.Registry, routeService *route.DBService, scheduleService *schedule.Service, _ *conversation.Service, _ *accounts.Service, settingsService *settings.Service, searchProviderService *searchproviders.Service, manager *mcp.Manager, containerdHandler *handlers.ContainerdHandler, mcpConnService *mcp.ConnectionService, mediaService *media.Service, inboxService *inbox.Service, memoryRegistry *memprovider.Registry, emailService *emailpkg.Service, emailManager *emailpkg.Manager, fedGateway *handlers.MCPFederationGateway, oauthService *mcp.OAuthService) *mcp.ToolGatewayService { fedGateway.SetOAuthService(oauthService) var assetResolver mcpmessage.AssetResolver if mediaService != nil { @@ -331,19 +347,19 @@ func provideToolGatewayService(log *slog.Logger, cfg config.Config, channelManag containerdHandler.SetToolGatewayService(svc) return svc } -func provideMemoryHandler(log *slog.Logger, botService *bots.Service, accountService *accounts.Service, cfg config.Config, manager *mcp.Manager, memoryRegistry *memprovider.Registry, settingsService *settings.Service, containerdHandler *handlers.ContainerdHandler) *handlers.MemoryHandler { + +func provideMemoryHandler(log *slog.Logger, botService *bots.Service, accountService *accounts.Service, _ config.Config, manager *mcp.Manager, memoryRegistry *memprovider.Registry, settingsService *settings.Service, _ *handlers.ContainerdHandler) *handlers.MemoryHandler { h := handlers.NewMemoryHandler(log, botService, accountService) h.SetMemoryRegistry(memoryRegistry) h.SetSettingsService(settingsService) h.SetMCPClientProvider(manager) return h } -func provideAuthHandler(log *slog.Logger, accountService *accounts.Service, rc *boot.RuntimeConfig) *handlers.AuthHandler { - return handlers.NewAuthHandler(log, accountService, rc.JwtSecret, rc.JwtExpiresIn) -} + func provideMemohAuthHandler(log *slog.Logger, accountService *accounts.Service, rc *boot.RuntimeConfig) *memohAuthHandler { return &memohAuthHandler{inner: handlers.NewAuthHandler(log, accountService, rc.JwtSecret, rc.JwtExpiresIn)} } + func provideMessageHandler(log *slog.Logger, chatService *conversation.Service, msgService *message.DBService, mediaService *media.Service, botService *bots.Service, accountService *accounts.Service, hub *event.Hub) *handlers.MessageHandler { h := handlers.NewMessageHandler(log, chatService, msgService, botService, accountService, hub) h.SetMediaService(mediaService) @@ -356,16 +372,20 @@ func (h *memohAuthHandler) Register(e *echo.Echo) { e.POST("/api/auth/login", h.inner.Login) e.POST("/api/auth/refresh", h.inner.Refresh) } + func provideMediaService(log *slog.Logger, manager *mcp.Manager) *media.Service { provider := containerfs.New(manager) return media.NewService(log, provider) } + func provideUsersHandler(log *slog.Logger, accountService *accounts.Service, identityService *identities.Service, botService *bots.Service, routeService *route.DBService, channelStore *channel.Store, channelLifecycle *channel.Lifecycle, channelManager *channel.Manager, registry *channel.Registry) *handlers.UsersHandler { return handlers.NewUsersHandler(log, accountService, identityService, botService, routeService, channelStore, channelLifecycle, channelManager, registry) } + func provideCLIHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler { return handlers.NewLocalChannelHandler(local.CLIType, channelManager, channelStore, chatService, hub, botService, accountService) } + func provideWebHandler(channelManager *channel.Manager, channelStore *channel.Store, chatService *conversation.Service, hub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *handlers.LocalChannelHandler { return handlers.NewLocalChannelHandler(local.WebType, channelManager, channelStore, chatService, hub, botService, accountService) } @@ -480,9 +500,11 @@ func provideServer(params serverParams) *memohServer { } return &memohServer{echo: e, addr: addr} } + func startScheduleService(lc fx.Lifecycle, scheduleService *schedule.Service) { lc.Append(fx.Hook{OnStart: func(ctx context.Context) error { return scheduleService.Bootstrap(ctx) }}) } + func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager) { ctx, cancel := context.WithCancel(context.Background()) lc.Append(fx.Hook{ @@ -490,15 +512,18 @@ func startChannelManager(lc fx.Lifecycle, channelManager *channel.Manager) { OnStop: func(stopCtx context.Context) error { cancel(); return channelManager.Shutdown(stopCtx) }, }) } + func startContainerReconciliation(lc fx.Lifecycle, containerdHandler *handlers.ContainerdHandler, _ *mcp.ToolGatewayService) { lc.Append(fx.Hook{OnStart: func(ctx context.Context) error { go containerdHandler.ReconcileContainers(ctx); return nil }}) } + func startAgentRuntime(lc fx.Lifecycle, manager *agentruntime.Manager) { lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { return manager.Start(ctx) }, OnStop: func(ctx context.Context) error { return manager.Stop(ctx) }, }) } + func startServer(lc fx.Lifecycle, logger *slog.Logger, srv *memohServer, shutdowner fx.Shutdowner, cfg config.Config, queries *dbsqlc.Queries, botService *bots.Service, containerdHandler *handlers.ContainerdHandler, manager *mcp.Manager, mcpConnService *mcp.ConnectionService, toolGateway *mcp.ToolGatewayService, channelManager *channel.Manager) { fmt.Printf("Starting Memoh Agent %s\n", version.GetInfo()) lc.Append(fx.Hook{ @@ -583,18 +608,22 @@ func hasAnyPrefix(path string, prefixes []string) bool { } return false } + func provideEmailRegistry(log *slog.Logger) *emailpkg.Registry { reg := emailpkg.NewRegistry() reg.Register(emailgeneric.New(log)) reg.Register(emailmailgun.New(log)) return reg } + func provideEmailChatGateway(resolver *flow.Resolver, queries *dbsqlc.Queries, cfg config.Config, log *slog.Logger) emailpkg.ChatTriggerer { return flow.NewEmailChatGateway(resolver, queries, cfg.Auth.JWTSecret, log) } + func provideEmailTrigger(log *slog.Logger, service *emailpkg.Service, botInbox *inbox.Service, chatTriggerer emailpkg.ChatTriggerer) *emailpkg.Trigger { return emailpkg.NewTrigger(log, service, botInbox, chatTriggerer) } + func startEmailManager(lc fx.Lifecycle, emailManager *emailpkg.Manager) { ctx, cancel := context.WithCancel(context.Background()) lc.Append(fx.Hook{ @@ -606,12 +635,13 @@ func startEmailManager(lc fx.Lifecycle, emailManager *emailpkg.Manager) { }() return nil }, - OnStop: func(_ context.Context) error { cancel(); emailManager.Stop(); return nil }, + OnStop: func(stopCtx context.Context) error { cancel(); emailManager.Stop(stopCtx); return nil }, }) } + func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error { if queries == nil { - return fmt.Errorf("db queries not configured") + return errors.New("db queries not configured") } count, err := queries.CountAccounts(ctx) if err != nil { @@ -624,7 +654,7 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer password := strings.TrimSpace(cfg.Admin.Password) email := strings.TrimSpace(cfg.Admin.Email) if username == "" || password == "" { - return fmt.Errorf("admin username/password required in config.toml") + return errors.New("admin username/password required in config.toml") } if password == "change-your-password-here" { log.Warn("admin password uses default placeholder; please update config.toml") @@ -669,6 +699,7 @@ func (c *lazyLLMClient) Extract(ctx context.Context, req memprovider.ExtractRequ } return client.Extract(ctx, req) } + func (c *lazyLLMClient) Decide(ctx context.Context, req memprovider.DecideRequest) (memprovider.DecideResponse, error) { client, err := c.resolve(ctx) if err != nil { @@ -676,6 +707,7 @@ func (c *lazyLLMClient) Decide(ctx context.Context, req memprovider.DecideReques } return client.Decide(ctx, req) } + func (c *lazyLLMClient) Compact(ctx context.Context, req memprovider.CompactRequest) (memprovider.CompactResponse, error) { client, err := c.resolve(ctx) if err != nil { @@ -683,6 +715,7 @@ func (c *lazyLLMClient) Compact(ctx context.Context, req memprovider.CompactRequ } return client.Compact(ctx, req) } + func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string, error) { client, err := c.resolve(ctx) if err != nil { @@ -690,9 +723,10 @@ func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string } return client.DetectLanguage(ctx, text) } + func (c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) { if c.modelsService == nil || c.queries == nil { - return nil, fmt.Errorf("models service not configured") + return nil, errors.New("models service not configured") } botID := "" memoryModel, memoryProvider, err := models.SelectMemoryModelForBot(ctx, c.modelsService, c.queries, botID) @@ -707,7 +741,7 @@ func (c *lazyLLMClient) resolve(ctx context.Context) (memprovider.LLM, error) { } _ = memoryProvider _ = memoryModel - return nil, fmt.Errorf("memory llm runtime is not available") + return nil, errors.New("memory llm runtime is not available") } type skillLoaderAdapter struct{ handler *handlers.ContainerdHandler } @@ -728,7 +762,7 @@ type mediaAssetResolverAdapter struct{ media *media.Service } func (a *mediaAssetResolverAdapter) GetByStorageKey(ctx context.Context, botID, storageKey string) (mcpmessage.AssetMeta, error) { if a == nil || a.media == nil { - return mcpmessage.AssetMeta{}, fmt.Errorf("media service not configured") + return mcpmessage.AssetMeta{}, errors.New("media service not configured") } asset, err := a.media.GetByStorageKey(ctx, botID, storageKey) if err != nil { @@ -736,9 +770,10 @@ func (a *mediaAssetResolverAdapter) GetByStorageKey(ctx context.Context, botID, } return mcpmessage.AssetMeta{ContentHash: asset.ContentHash, Mime: asset.Mime, SizeBytes: asset.SizeBytes, StorageKey: asset.StorageKey}, nil } + func (a *mediaAssetResolverAdapter) IngestContainerFile(ctx context.Context, botID, containerPath string) (mcpmessage.AssetMeta, error) { if a == nil || a.media == nil { - return mcpmessage.AssetMeta{}, fmt.Errorf("media service not configured") + return mcpmessage.AssetMeta{}, errors.New("media service not configured") } asset, err := a.media.IngestContainerFile(ctx, botID, containerPath) if err != nil { @@ -751,7 +786,7 @@ type gatewayAssetLoaderAdapter struct{ media *media.Service } func (a *gatewayAssetLoaderAdapter) OpenForGateway(ctx context.Context, botID, contentHash string) (io.ReadCloser, string, error) { if a == nil || a.media == nil { - return nil, "", fmt.Errorf("media service not configured") + return nil, "", errors.New("media service not configured") } reader, asset, err := a.media.Open(ctx, botID, contentHash) if err != nil { diff --git a/internal/accounts/service.go b/internal/accounts/service.go index 90aea315..28116b6d 100644 --- a/internal/accounts/service.go +++ b/internal/accounts/service.go @@ -42,7 +42,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { // Get returns an account by user id. func (s *Service) Get(ctx context.Context, userID string) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -58,7 +58,7 @@ func (s *Service) Get(ctx context.Context, userID string) (Account, error) { // Login authenticates by identity (username or email) and password. func (s *Service) Login(ctx context.Context, identity, password string) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } identity = strings.TrimSpace(identity) if identity == "" || strings.TrimSpace(password) == "" { @@ -91,7 +91,7 @@ func (s *Service) Login(ctx context.Context, identity, password string) (Account // ListAccounts returns all accounts. func (s *Service) ListAccounts(ctx context.Context) ([]Account, error) { if s.queries == nil { - return nil, fmt.Errorf("account queries not configured") + return nil, errors.New("account queries not configured") } rows, err := s.queries.ListAccounts(ctx) if err != nil { @@ -107,7 +107,7 @@ func (s *Service) ListAccounts(ctx context.Context) ([]Account, error) { // IsAdmin checks if the user has admin role. func (s *Service) IsAdmin(ctx context.Context, userID string) (bool, error) { if s.queries == nil { - return false, fmt.Errorf("account queries not configured") + return false, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -126,15 +126,15 @@ func (s *Service) IsAdmin(ctx context.Context, userID string) (bool, error) { // Create creates a new account for an existing user. func (s *Service) Create(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } username := strings.TrimSpace(req.Username) if username == "" { - return Account{}, fmt.Errorf("username is required") + return Account{}, errors.New("username is required") } password := strings.TrimSpace(req.Password) if password == "" { - return Account{}, fmt.Errorf("password is required") + return Account{}, errors.New("password is required") } role, err := normalizeRole(req.Role) if err != nil { @@ -195,7 +195,7 @@ func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAcco userID = strings.TrimSpace(userID) if userID == "" { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } userRow, err := s.queries.CreateUser(ctx, sqlc.CreateUserParams{ IsActive: true, @@ -205,7 +205,7 @@ func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAcco return Account{}, err } if !userRow.ID.Valid { - return Account{}, fmt.Errorf("create user: invalid id") + return Account{}, errors.New("create user: invalid id") } userID = userRow.ID.String() } @@ -215,7 +215,7 @@ func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAcco // UpdateAdmin updates account fields as admin. func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAccountRequest) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -264,7 +264,7 @@ func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAcco // UpdateProfile updates the user's profile. func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdateProfileRequest) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -300,10 +300,10 @@ func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdatePr // UpdatePassword changes the password after verifying the current one. func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, newPassword string) error { if s.queries == nil { - return fmt.Errorf("account queries not configured") + return errors.New("account queries not configured") } if strings.TrimSpace(newPassword) == "" { - return fmt.Errorf("new password is required") + return errors.New("new password is required") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -336,10 +336,10 @@ func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, n // ResetPassword sets a new password without requiring the current one. func (s *Service) ResetPassword(ctx context.Context, userID, newPassword string) error { if s.queries == nil { - return fmt.Errorf("account queries not configured") + return errors.New("account queries not configured") } if strings.TrimSpace(newPassword) == "" { - return fmt.Errorf("new password is required") + return errors.New("new password is required") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -423,4 +423,3 @@ func toAccount(row sqlc.User) Account { LastLoginAt: lastLogin, } } - diff --git a/internal/accounts/types.go b/internal/accounts/types.go index 7a3b4f62..2dbc553a 100644 --- a/internal/accounts/types.go +++ b/internal/accounts/types.go @@ -19,7 +19,7 @@ type Account struct { // CreateAccountRequest is the input for creating an account. type CreateAccountRequest struct { Username string `json:"username"` - Password string `json:"password"` + Password string `json:"password"` //nolint:gosec // intentional: JSON request field carrying a user-supplied credential Email string `json:"email,omitempty"` Role string `json:"role,omitempty"` DisplayName string `json:"display_name,omitempty"` diff --git a/internal/attachment/normalize.go b/internal/attachment/normalize.go index ee462f33..d20e822b 100644 --- a/internal/attachment/normalize.go +++ b/internal/attachment/normalize.go @@ -3,6 +3,7 @@ package attachment import ( "bytes" "encoding/base64" + "errors" "fmt" "io" "net/http" @@ -91,7 +92,7 @@ func ResolveMime(mediaType media.MediaType, sourceMime, sniffedMime string) stri // PrepareReaderAndMime reads a small prefix for MIME sniffing and replays it. func PrepareReaderAndMime(reader io.Reader, mediaType media.MediaType, sourceMime string) (io.Reader, string, error) { if reader == nil { - return nil, "", fmt.Errorf("reader is required") + return nil, "", errors.New("reader is required") } header := make([]byte, 512) n, err := reader.Read(header) @@ -128,7 +129,7 @@ func NormalizeBase64DataURL(input, mime string) string { func DecodeBase64(input string, maxBytes int64) (io.Reader, error) { value := strings.TrimSpace(input) if value == "" { - return nil, fmt.Errorf("base64 payload is empty") + return nil, errors.New("base64 payload is empty") } if strings.HasPrefix(strings.ToLower(value), "data:") { if idx := strings.Index(value, ","); idx >= 0 { diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 340cab23..bd23c7f4 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "fmt" "net/http" "strings" @@ -30,7 +31,7 @@ func JWTMiddleware(secret string, skipper middleware.Skipper) echo.MiddlewareFun SigningMethod: "HS256", TokenLookup: "header:Authorization:Bearer ,query:token", Skipper: skipper, - NewClaimsFunc: func(c echo.Context) jwt.Claims { + NewClaimsFunc: func(_ echo.Context) jwt.Claims { return jwt.MapClaims{} }, }) @@ -58,13 +59,13 @@ func UserIDFromContext(c echo.Context) (string, error) { // GenerateToken creates a signed JWT for the user. func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(userID) == "" { - return "", time.Time{}, fmt.Errorf("user id is required") + return "", time.Time{}, errors.New("user id is required") } if strings.TrimSpace(secret) == "" { - return "", time.Time{}, fmt.Errorf("jwt secret is required") + return "", time.Time{}, errors.New("jwt secret is required") } if expiresIn <= 0 { - return "", time.Time{}, fmt.Errorf("jwt expires in must be positive") + return "", time.Time{}, errors.New("jwt expires in must be positive") } now := time.Now().UTC() @@ -95,22 +96,22 @@ type ChatToken struct { // GenerateChatToken creates a signed JWT for chat route reply. func GenerateChatToken(info ChatToken, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(info.BotID) == "" { - return "", time.Time{}, fmt.Errorf("bot id is required") + return "", time.Time{}, errors.New("bot id is required") } if strings.TrimSpace(info.ChatID) == "" { - return "", time.Time{}, fmt.Errorf("chat id is required") + return "", time.Time{}, errors.New("chat id is required") } if strings.TrimSpace(info.UserID) == "" { info.UserID = strings.TrimSpace(info.ChannelIdentityID) } if strings.TrimSpace(info.UserID) == "" { - return "", time.Time{}, fmt.Errorf("user id is required") + return "", time.Time{}, errors.New("user id is required") } if strings.TrimSpace(secret) == "" { - return "", time.Time{}, fmt.Errorf("jwt secret is required") + return "", time.Time{}, errors.New("jwt secret is required") } if expiresIn <= 0 { - return "", time.Time{}, fmt.Errorf("jwt expires in must be positive") + return "", time.Time{}, errors.New("jwt expires in must be positive") } now := time.Now().UTC() diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go index 2b48f6eb..f59c3ffd 100644 --- a/internal/auth/jwt_test.go +++ b/internal/auth/jwt_test.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -9,6 +10,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRefreshTokenFromContext(t *testing.T) { @@ -23,13 +25,13 @@ func TestRefreshTokenFromContext(t *testing.T) { // Create an initial token with a 5-minute lifespan initialDuration := 5 * time.Minute initialTokenStr, _, err := GenerateToken(userID, secret, initialDuration) - assert.NoError(t, err) + require.NoError(t, err) // Parse the token to place it into the echo context - token, err := jwt.Parse(initialTokenStr, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.Parse(initialTokenStr, func(_ *jwt.Token) (interface{}, error) { return []byte(secret), nil }) - assert.NoError(t, err) + require.NoError(t, err) c.Set("user", token) // Simulate some time passing to ensure the new token has a different 'iat' and 'exp' @@ -38,7 +40,7 @@ func TestRefreshTokenFromContext(t *testing.T) { // Run the refresh function defaultDuration := 1 * time.Hour newTokenStr, newExpiresAt, err := RefreshTokenFromContext(c, secret, defaultDuration) - assert.NoError(t, err) + require.NoError(t, err) assert.NotEmpty(t, newTokenStr) // Parse the original token claims for comparison @@ -47,10 +49,10 @@ func TestRefreshTokenFromContext(t *testing.T) { origIat := int64(originalClaims["iat"].(float64)) // Parse the new token - newToken, err := jwt.Parse(newTokenStr, func(token *jwt.Token) (interface{}, error) { + newToken, err := jwt.Parse(newTokenStr, func(_ *jwt.Token) (interface{}, error) { return []byte(secret), nil }) - assert.NoError(t, err) + require.NoError(t, err) assert.True(t, newToken.Valid) newClaims, ok := newToken.Claims.(jwt.MapClaims) @@ -69,11 +71,11 @@ func TestRefreshTokenFromContext(t *testing.T) { // 2. Ensure the refreshed token has a positive lifetime and does not exceed the configured default duration lifetimeSeconds := newExp - newIat - assert.Greater(t, lifetimeSeconds, int64(0)) + assert.Positive(t, lifetimeSeconds) assert.LessOrEqual(t, lifetimeSeconds, int64(defaultDuration.Seconds())) // 3. Ensure the return value matches the claim - assert.Equal(t, newExpiresAt.Unix(), newExp) + assert.Equal(t, newExp, newExpiresAt.Unix()) } func TestRefreshTokenFromContext_MissingUser(t *testing.T) { @@ -87,9 +89,10 @@ func TestRefreshTokenFromContext_MissingUser(t *testing.T) { // Context without the "user" key _, _, err := RefreshTokenFromContext(c, secret, defaultDuration) - assert.Error(t, err) + require.Error(t, err) - httpErr, ok := err.(*echo.HTTPError) + httpErr := &echo.HTTPError{} + ok := errors.As(err, &httpErr) assert.True(t, ok) assert.Equal(t, http.StatusUnauthorized, httpErr.Code) assert.Equal(t, "invalid token", httpErr.Message) diff --git a/internal/bind/service.go b/internal/bind/service.go index a0c84188..2b16e245 100644 --- a/internal/bind/service.go +++ b/internal/bind/service.go @@ -46,7 +46,7 @@ func NewService(log *slog.Logger, pool *pgxpool.Pool, queries *sqlc.Queries) *Se // Platform is optional; when provided, bind consume must happen on the same channel platform. func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, ttl time.Duration) (Code, error) { if s.queries == nil { - return Code{}, fmt.Errorf("bind queries not configured") + return Code{}, errors.New("bind queries not configured") } if ttl <= 0 { ttl = defaultTTL @@ -78,13 +78,13 @@ func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, tt } return Code{}, fmt.Errorf("create bind code: %w", err) } - return Code{}, fmt.Errorf("create bind code: token collision after retries") + return Code{}, errors.New("create bind code: token collision after retries") } // Get looks up a bind code by token. func (s *Service) Get(ctx context.Context, token string) (Code, error) { if s.queries == nil { - return Code{}, fmt.Errorf("bind queries not configured") + return Code{}, errors.New("bind queries not configured") } row, err := s.queries.GetBindCode(ctx, strings.TrimSpace(token)) if err != nil { @@ -99,7 +99,7 @@ func (s *Service) Get(ctx context.Context, token string) (Code, error) { // Consume validates and consumes a bind code and links the channel identity to issuer user. func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID string) error { if s.queries == nil || s.pool == nil { - return fmt.Errorf("bind service not configured") + return errors.New("bind service not configured") } // Fast-fail based on caller snapshot before opening a transaction. @@ -115,7 +115,7 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri } sourceIdentityID := strings.TrimSpace(channelIdentityID) if sourceIdentityID == "" { - return fmt.Errorf("channel identity id is required") + return errors.New("channel identity id is required") } pgSourceIdentityID, err := db.ParseUUID(sourceIdentityID) if err != nil { @@ -149,7 +149,7 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri targetUserID := strings.TrimSpace(lockedCode.IssuedByUserID) if targetUserID == "" { - return fmt.Errorf("bind code issuer user is missing") + return errors.New("bind code issuer user is missing") } pgTargetUserID, err := db.ParseUUID(targetUserID) if err != nil { @@ -158,14 +158,14 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri if _, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID); err != nil { if errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("channel identity not found") + return errors.New("channel identity not found") } return fmt.Errorf("lock source identity: %w", err) } sourceIdentity, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("channel identity not found") + return errors.New("channel identity not found") } return fmt.Errorf("reload source identity: %w", err) } diff --git a/internal/bind/service_consume_integration_test.go b/internal/bind/service_consume_integration_test.go index 735eddb9..c5bad185 100644 --- a/internal/bind/service_consume_integration_test.go +++ b/internal/bind/service_consume_integration_test.go @@ -2,7 +2,6 @@ package bind_test import ( "context" - "encoding/json" "errors" "fmt" "log/slog" @@ -10,12 +9,10 @@ import ( "testing" "time" - "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel/identities" - "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -55,28 +52,6 @@ func createUserForBind(ctx context.Context, queries *sqlc.Queries) (string, erro return row.ID.String(), nil } -func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { - pgOwnerID, err := db.ParseUUID(ownerUserID) - if err != nil { - return "", err - } - meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"}) - if err != nil { - return "", err - } - row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ - OwnerUserID: pgOwnerID, - Type: "personal", - DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, - IsActive: true, - Metadata: meta, - }) - if err != nil { - return "", err - } - return row.ID.String(), nil -} - func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { queries, channelIdentitySvc, bindSvc, cleanup := setupBindConsumeIntegrationTest(t) defer cleanup() diff --git a/internal/bind/service_test.go b/internal/bind/service_test.go index 298a744b..a9d8c4e4 100644 --- a/internal/bind/service_test.go +++ b/internal/bind/service_test.go @@ -122,13 +122,13 @@ func TestToCode_OptionalFields(t *testing.T) { } now := time.Now().UTC() row := sqlc.ChannelIdentityBindCode{ - ID: pgID, - Token: "TOKEN", - IssuedByUserID: pgID, - ChannelType: pgtype.Text{Valid: false}, - ExpiresAt: pgtype.Timestamptz{Valid: false}, - UsedAt: pgtype.Timestamptz{Valid: false}, - CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, + ID: pgID, + Token: "TOKEN", + IssuedByUserID: pgID, + ChannelType: pgtype.Text{Valid: false}, + ExpiresAt: pgtype.Timestamptz{Valid: false}, + UsedAt: pgtype.Timestamptz{Valid: false}, + CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, } c := toCode(row) if c.Platform != "" { @@ -204,4 +204,3 @@ func TestService_Consume_InvalidChannelIdentityID(t *testing.T) { t.Fatal("expected error for invalid channel identity id") } } - diff --git a/internal/boot/runtime.go b/internal/boot/runtime.go index 7f6ef8a3..4887d29d 100644 --- a/internal/boot/runtime.go +++ b/internal/boot/runtime.go @@ -12,7 +12,7 @@ import ( ) type RuntimeConfig struct { - JwtSecret string + JwtSecret string `json:"-"` JwtExpiresIn time.Duration ServerAddr string ContainerdSocketPath string diff --git a/internal/bots/service.go b/internal/bots/service.go index da852385..1ae0b418 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -74,7 +74,7 @@ func (s *Service) AddRuntimeChecker(c RuntimeChecker) { // AuthorizeAccess checks whether userID may access the given bot. func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } bot, err := s.Get(ctx, botID) if err != nil { @@ -102,11 +102,11 @@ func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isA // Create creates a new bot owned by owner user. func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotRequest) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } ownerID := strings.TrimSpace(ownerUserID) if ownerID == "" { - return Bot{}, fmt.Errorf("owner user id is required") + return Bot{}, errors.New("owner user id is required") } ownerUUID, err := db.ParseUUID(ownerID) if err != nil { @@ -155,14 +155,14 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR if err := s.attachCheckSummary(ctx, &bot, asSQLCBot(row)); err != nil { return Bot{}, err } - s.enqueueCreateLifecycle(bot.ID) + s.enqueueCreateLifecycle(ctx, bot.ID) return bot, nil } // Get returns a bot by its ID. func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -185,7 +185,7 @@ func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { // ListByOwner returns bots owned by the given user. func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } ownerUUID, err := db.ParseUUID(ownerUserID) if err != nil { @@ -212,7 +212,7 @@ func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, e // ListByMember returns bots where the user is a member. func (s *Service) ListByMember(ctx context.Context, channelIdentityID string) ([]Bot, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -265,7 +265,7 @@ func (s *Service) ListAccessible(ctx context.Context, channelIdentityID string) // Update updates bot profile fields. func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -324,7 +324,7 @@ func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest // TransferOwner transfers bot ownership to another user. func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID string) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -357,7 +357,7 @@ func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID s // Delete removes a bot and its associated resources. func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -376,14 +376,14 @@ func (s *Service) Delete(ctx context.Context, botID string) error { }); err != nil { return err } - s.enqueueDeleteLifecycle(botID) + s.enqueueDeleteLifecycle(ctx, botID) return nil } // ListChecks evaluates runtime resource checks for a bot. func (s *Service) ListChecks(ctx context.Context, botID string) ([]BotCheck, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -396,13 +396,13 @@ func (s *Service) ListChecks(ctx context.Context, botID string) ([]BotCheck, err return s.buildRuntimeChecks(ctx, asSQLCBot(row), true) } -func (s *Service) enqueueCreateLifecycle(botID string) { +func (s *Service) enqueueCreateLifecycle(ctx context.Context, botID string) { go func() { - ctx, cancel := context.WithTimeout(context.Background(), botLifecycleOperationTimeout) + lifecycleCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), botLifecycleOperationTimeout) defer cancel() if s.containerLifecycle != nil { - if err := s.containerLifecycle.SetupBotContainer(ctx, botID); err != nil { + if err := s.containerLifecycle.SetupBotContainer(lifecycleCtx, botID); err != nil { s.logger.Error("bot container setup failed", slog.String("bot_id", botID), slog.Any("error", err), @@ -410,7 +410,7 @@ func (s *Service) enqueueCreateLifecycle(botID string) { } } - if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil { + if err := s.updateStatus(lifecycleCtx, botID, BotStatusReady); err != nil { s.logger.Error("failed to update bot status to ready after create", slog.String("bot_id", botID), slog.Any("error", err), @@ -419,13 +419,13 @@ func (s *Service) enqueueCreateLifecycle(botID string) { }() } -func (s *Service) enqueueDeleteLifecycle(botID string) { +func (s *Service) enqueueDeleteLifecycle(ctx context.Context, botID string) { go func() { - ctx, cancel := context.WithTimeout(context.Background(), botLifecycleOperationTimeout) + lifecycleCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), botLifecycleOperationTimeout) defer cancel() if s.containerLifecycle != nil { - if err := s.containerLifecycle.CleanupBotContainer(ctx, botID); err != nil { + if err := s.containerLifecycle.CleanupBotContainer(lifecycleCtx, botID); err != nil { s.logger.Error("bot container cleanup failed", slog.String("bot_id", botID), slog.Any("error", err), @@ -439,17 +439,17 @@ func (s *Service) enqueueDeleteLifecycle(botID string) { slog.String("bot_id", botID), slog.Any("error", err), ) - if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil { + if err := s.updateStatus(lifecycleCtx, botID, BotStatusReady); err != nil { s.logger.Error("revert bot status failed", slog.String("bot_id", botID), slog.Any("error", err)) } return } - if err := s.queries.DeleteBotByID(ctx, botUUID); err != nil { + if err := s.queries.DeleteBotByID(lifecycleCtx, botUUID); err != nil { s.logger.Error("failed to delete bot after cleanup", slog.String("bot_id", botID), slog.Any("error", err), ) - if err := s.updateStatus(ctx, botID, BotStatusReady); err != nil { + if err := s.updateStatus(lifecycleCtx, botID, BotStatusReady); err != nil { s.logger.Error("revert bot status failed", slog.String("bot_id", botID), slog.Any("error", err)) } return @@ -459,7 +459,7 @@ func (s *Service) enqueueDeleteLifecycle(botID string) { func (s *Service) updateStatus(ctx context.Context, botID, status string) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -473,7 +473,7 @@ func (s *Service) updateStatus(ctx context.Context, botID, status string) error func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } _, err := s.queries.GetUserByID(ctx, userID) if err != nil { @@ -488,7 +488,7 @@ func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) erro // UpsertMember creates or updates a bot membership. func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemberRequest) (BotMember, error) { if s.queries == nil { - return BotMember{}, fmt.Errorf("bot queries not configured") + return BotMember{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -516,7 +516,7 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb // ListMembers returns all members of a bot. func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -536,7 +536,7 @@ func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, e // GetMember returns a specific bot member. func (s *Service) GetMember(ctx context.Context, botID, channelIdentityID string) (BotMember, error) { if s.queries == nil { - return BotMember{}, fmt.Errorf("bot queries not configured") + return BotMember{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -559,7 +559,7 @@ func (s *Service) GetMember(ctx context.Context, botID, channelIdentityID string // DeleteMember removes a member from a bot. func (s *Service) DeleteMember(ctx context.Context, botID, channelIdentityID string) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -830,14 +830,14 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot, includeD Summary: "Container task state is unknown.", Detail: "Task state cannot be determined without a container record.", }) - checks = append(checks, BotCheck{ - ID: BotCheckTypeContainerData, - Type: BotCheckTypeContainerData, - TitleKey: "bots.checks.titles.containerDataPath", - Status: BotCheckStatusUnknown, - Summary: "Container reachability is unknown.", - Detail: "Reachability cannot be determined without a container record.", - }) + checks = append(checks, BotCheck{ + ID: BotCheckTypeContainerData, + Type: BotCheckTypeContainerData, + TitleKey: "bots.checks.titles.containerDataPath", + Status: BotCheckStatusUnknown, + Summary: "Container reachability is unknown.", + Detail: "Reachability cannot be determined without a container record.", + }) if includeDynamic { checks = s.appendDynamicChecks(ctx, row.ID.String(), checks) } diff --git a/internal/bots/service_test.go b/internal/bots/service_test.go index edeee832..0298ace2 100644 --- a/internal/bots/service_test.go +++ b/internal/bots/service_test.go @@ -25,11 +25,11 @@ type fakeDBTX struct { queryRowFunc func(ctx context.Context, sql string, args ...any) pgx.Row } -func (d *fakeDBTX) Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) { +func (*fakeDBTX) Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) { return pgconn.CommandTag{}, nil } -func (d *fakeDBTX) Query(context.Context, string, ...interface{}) (pgx.Rows, error) { +func (*fakeDBTX) Query(context.Context, string, ...interface{}) (pgx.Rows, error) { return nil, nil } @@ -37,14 +37,14 @@ func (d *fakeDBTX) QueryRow(ctx context.Context, sql string, args ...any) pgx.Ro if d.queryRowFunc != nil { return d.queryRowFunc(ctx, sql, args...) } - return &fakeRow{scanFunc: func(dest ...any) error { return pgx.ErrNoRows }} + return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }} } // makeBotRow creates a fakeRow that populates a sqlc.Bot via Scan. // Column order: id, owner_user_id, type, display_name, avatar_url, is_active, status, // max_context_load_time, max_context_tokens, max_inbox_items, language, allow_guest, // reasoning_enabled, reasoning_effort, chat_model_id, search_provider_id, memory_provider_id, -// heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at +// heartbeat_enabled, heartbeat_interval, heartbeat_prompt, metadata, created_at, updated_at. func makeBotRow(botID, ownerUserID pgtype.UUID, botType string, allowGuest bool) *fakeRow { return &fakeRow{ scanFunc: func(dest ...any) error { @@ -63,8 +63,8 @@ func makeBotRow(botID, ownerUserID pgtype.UUID, botType string, allowGuest bool) *dest[9].(*int32) = 10 // MaxInboxItems *dest[10].(*string) = "en" *dest[11].(*bool) = allowGuest - *dest[12].(*bool) = false // ReasoningEnabled - *dest[13].(*string) = "medium" // ReasoningEffort + *dest[12].(*bool) = false // ReasoningEnabled + *dest[13].(*string) = "medium" // ReasoningEffort *dest[14].(*pgtype.UUID) = pgtype.UUID{} // ChatModelID *dest[15].(*pgtype.UUID) = pgtype.UUID{} // SearchProviderID *dest[16].(*pgtype.UUID) = pgtype.UUID{} // MemoryProviderID @@ -95,7 +95,7 @@ func makeMemberRow(botID, userID pgtype.UUID) *fakeRow { } func makeNoRow() *fakeRow { - return &fakeRow{scanFunc: func(dest ...any) error { return pgx.ErrNoRows }} + return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }} } func mustParseUUID(s string) pgtype.UUID { @@ -198,7 +198,7 @@ func TestAuthorizeAccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db := &fakeDBTX{ - queryRowFunc: func(_ context.Context, sql string, args ...any) pgx.Row { + queryRowFunc: func(_ context.Context, _ string, args ...any) pgx.Row { // Route to bot or member row based on query. if len(args) == 1 { return makeBotRow(botUUID, ownerUUID, tt.botType, tt.allowGst) @@ -219,10 +219,8 @@ func TestAuthorizeAccess(t *testing.T) { if tt.wantErrIs != nil && err.Error() != tt.wantErrIs.Error() { t.Fatalf("expected error %q, got %q", tt.wantErrIs, err) } - } else { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + } else if err != nil { + t.Fatalf("unexpected error: %v", err) } }) } diff --git a/internal/bun/runtime/manager.go b/internal/bun/runtime/manager.go index 3da8d517..e858e4e5 100644 --- a/internal/bun/runtime/manager.go +++ b/internal/bun/runtime/manager.go @@ -88,7 +88,7 @@ func (m *Manager) Start(ctx context.Context) error { } return fmt.Errorf("agent binary missing: %w", err) } - if err := os.Chmod(agentBinPath, 0o755); err != nil { + if err := os.Chmod(agentBinPath, 0o755); err != nil { //nolint:gosec // G302: executable binary requires execute bit; 0600 would make it non-executable return fmt.Errorf("chmod agent binary: %w", err) } agentConfigPath := filepath.Join(agentDir, agentConfigFileName) @@ -96,7 +96,7 @@ func (m *Manager) Start(ctx context.Context) error { return err } - cmd := exec.Command(agentBinPath) + cmd := exec.CommandContext(ctx, agentBinPath) //nolint:gosec // G204: path is constructed internally from an embedded asset, not user input cmd.Dir = agentDir cmd.Env = append( os.Environ(), @@ -157,7 +157,7 @@ func (m *Manager) waitHealthy(ctx context.Context) error { deadline := time.Now().Add(healthCheckTimeout) for time.Now().Before(deadline) { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: URL is constructed from operator-configured host/port, not from user input if err == nil { _ = resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { @@ -174,7 +174,7 @@ func (m *Manager) address() string { } func extractFS(src fs.FS, targetDir string) error { - if err := os.MkdirAll(targetDir, 0o755); err != nil { + if err := os.MkdirAll(targetDir, 0o750); err != nil { return err } return fs.WalkDir(src, ".", func(path string, d fs.DirEntry, err error) error { @@ -186,19 +186,19 @@ func extractFS(src fs.FS, targetDir string) error { } target := filepath.Join(targetDir, path) if d.IsDir() { - return os.MkdirAll(target, 0o755) + return os.MkdirAll(target, 0o750) } - if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil { return err } r, err := src.Open(path) if err != nil { return err } - defer r.Close() + defer func() { _ = r.Close() }() - w, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) + w, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600) //nolint:gosec // G304: target is derived from an embedded FS walk within a process-owned temp dir if err != nil { return err } @@ -211,14 +211,14 @@ func extractFS(src fs.FS, targetDir string) error { } func writeAgentConfig(path string, cfg config.Config) error { - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { return err } - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec // G304: path is constructed internally from a process-owned temp dir if err != nil { return fmt.Errorf("create agent config: %w", err) } - defer f.Close() + defer func() { _ = f.Close() }() return toml.NewEncoder(f).Encode(cfg) } @@ -231,7 +231,7 @@ func (w *logWriter) Write(p []byte) (n int, err error) { msg := string(p) msg = trimTrailingNewline(msg) if msg != "" { - w.log.Log(context.Background(), w.level, msg) + w.log.LogAttrs(context.Background(), w.level, "runtime process output", slog.String("detail", msg)) } return len(p), nil } diff --git a/internal/channel/adapters/discord/config.go b/internal/channel/adapters/discord/config.go index 1d0c4b2f..1f06dc4e 100644 --- a/internal/channel/adapters/discord/config.go +++ b/internal/channel/adapters/discord/config.go @@ -1,131 +1,131 @@ package discord import ( - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" ) type Config struct { - BotToken string + BotToken string } type UserConfig struct { - UserID string - ChannelID string - GuildID string - Username string + UserID string + ChannelID string + GuildID string + Username string } func normalizeConfig(raw map[string]any) (map[string]any, error) { - cfg, err := parseConfig(raw) - if err != nil { - return nil, err - } - return map[string]any{"botToken": cfg.BotToken}, nil + cfg, err := parseConfig(raw) + if err != nil { + return nil, err + } + return map[string]any{"botToken": cfg.BotToken}, nil } func normalizeUserConfig(raw map[string]any) (map[string]any, error) { - cfg, err := parseUserConfig(raw) - if err != nil { - return nil, err - } - result := map[string]any{} - if cfg.UserID != "" { - result["user_id"] = cfg.UserID - } - if cfg.ChannelID != "" { - result["channel_id"] = cfg.ChannelID - } - if cfg.GuildID != "" { - result["guild_id"] = cfg.GuildID - } - if cfg.Username != "" { - result["username"] = cfg.Username - } - return result, nil + cfg, err := parseUserConfig(raw) + if err != nil { + return nil, err + } + result := map[string]any{} + if cfg.UserID != "" { + result["user_id"] = cfg.UserID + } + if cfg.ChannelID != "" { + result["channel_id"] = cfg.ChannelID + } + if cfg.GuildID != "" { + result["guild_id"] = cfg.GuildID + } + if cfg.Username != "" { + result["username"] = cfg.Username + } + return result, nil } func resolveTarget(raw map[string]any) (string, error) { - cfg, err := parseUserConfig(raw) - if err != nil { - return "", err - } - if cfg.ChannelID != "" { - return cfg.ChannelID, nil - } - if cfg.UserID != "" { - return cfg.UserID, nil - } - return "", fmt.Errorf("discord binding is incomplete") + cfg, err := parseUserConfig(raw) + if err != nil { + return "", err + } + if cfg.ChannelID != "" { + return cfg.ChannelID, nil + } + if cfg.UserID != "" { + return cfg.UserID, nil + } + return "", errors.New("discord binding is incomplete") } func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { - cfg, err := parseUserConfig(raw) - if err != nil { - return false - } - if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { - return true - } - if value := strings.TrimSpace(criteria.Attribute("channel_id")); value != "" && value == cfg.ChannelID { - return true - } - if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) { - return true - } - if criteria.SubjectID != "" { - if criteria.SubjectID == cfg.UserID || criteria.SubjectID == cfg.ChannelID { - return true - } - } - return false + cfg, err := parseUserConfig(raw) + if err != nil { + return false + } + if value := strings.TrimSpace(criteria.Attribute("user_id")); value != "" && value == cfg.UserID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("channel_id")); value != "" && value == cfg.ChannelID { + return true + } + if value := strings.TrimSpace(criteria.Attribute("username")); value != "" && strings.EqualFold(value, cfg.Username) { + return true + } + if criteria.SubjectID != "" { + if criteria.SubjectID == cfg.UserID || criteria.SubjectID == cfg.ChannelID { + return true + } + } + return false } func buildUserConfig(identity channel.Identity) map[string]any { - result := map[string]any{} - if value := strings.TrimSpace(identity.Attribute("user_id")); value != "" { - result["user_id"] = value - } - if value := strings.TrimSpace(identity.Attribute("channel_id")); value != "" { - result["channel_id"] = value - } - if value := strings.TrimSpace(identity.Attribute("guild_id")); value != "" { - result["guild_id"] = value - } - if value := strings.TrimSpace(identity.Attribute("username")); value != "" { - result["username"] = value - } - return result + result := map[string]any{} + if value := strings.TrimSpace(identity.Attribute("user_id")); value != "" { + result["user_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("channel_id")); value != "" { + result["channel_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("guild_id")); value != "" { + result["guild_id"] = value + } + if value := strings.TrimSpace(identity.Attribute("username")); value != "" { + result["username"] = value + } + return result } func parseConfig(raw map[string]any) (Config, error) { - token := strings.TrimSpace(channel.ReadString(raw, "botToken", "bot_token")) - if token == "" { - return Config{}, fmt.Errorf("discord botToken is required") - } - return Config{BotToken: token}, nil + token := strings.TrimSpace(channel.ReadString(raw, "botToken", "bot_token")) + if token == "" { + return Config{}, errors.New("discord botToken is required") + } + return Config{BotToken: token}, nil } func parseUserConfig(raw map[string]any) (UserConfig, error) { - userID := strings.TrimSpace(channel.ReadString(raw,"userId", "user_id")) - channelID := strings.TrimSpace(channel.ReadString(raw, "channelId", "channel_id")) - guildID := strings.TrimSpace(channel.ReadString(raw, "guildId", "guild_id")) - username := strings.TrimSpace(channel.ReadString(raw, "username")) + userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) + channelID := strings.TrimSpace(channel.ReadString(raw, "channelId", "channel_id")) + guildID := strings.TrimSpace(channel.ReadString(raw, "guildId", "guild_id")) + username := strings.TrimSpace(channel.ReadString(raw, "username")) - if userID == "" && channelID == "" { - return UserConfig{}, fmt.Errorf("discord user config requires user_id or channel_id") - } + if userID == "" && channelID == "" { + return UserConfig{}, errors.New("discord user config requires user_id or channel_id") + } - return UserConfig{ - UserID: userID, - ChannelID: channelID, - GuildID: guildID, - Username: username, - }, nil + return UserConfig{ + UserID: userID, + ChannelID: channelID, + GuildID: guildID, + Username: username, + }, nil } func normalizeTarget(raw string) string { - return strings.TrimSpace(raw) -} \ No newline at end of file + return strings.TrimSpace(raw) +} diff --git a/internal/channel/adapters/discord/descriptor.go b/internal/channel/adapters/discord/descriptor.go index 5e44e12a..bd93cbe0 100644 --- a/internal/channel/adapters/discord/descriptor.go +++ b/internal/channel/adapters/discord/descriptor.go @@ -2,4 +2,4 @@ package discord import "github.com/memohai/memoh/internal/channel" -const Type channel.ChannelType = "discord" \ No newline at end of file +const Type channel.ChannelType = "discord" diff --git a/internal/channel/adapters/discord/discord.go b/internal/channel/adapters/discord/discord.go index 529c3997..bfb63f9c 100644 --- a/internal/channel/adapters/discord/discord.go +++ b/internal/channel/adapters/discord/discord.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/base64" + "errors" "fmt" "io" "log/slog" @@ -14,6 +15,7 @@ import ( "unicode/utf8" "github.com/bwmarrin/discordgo" + "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/channel/adapters/common" "github.com/memohai/memoh/internal/media" @@ -57,11 +59,11 @@ func (a *DiscordAdapter) SetAssetOpener(opener assetOpener) { a.assets = opener } -func (a *DiscordAdapter) Type() channel.ChannelType { +func (*DiscordAdapter) Type() channel.ChannelType { return Type } -func (a *DiscordAdapter) Descriptor() channel.Descriptor { +func (*DiscordAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: Type, DisplayName: "Discord", @@ -159,7 +161,7 @@ func (a *DiscordAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, } text := strings.TrimSpace(m.Content) - botId := s.State.User.ID + botID := s.State.User.ID if text == "" && len(m.Attachments) == 0 { return } @@ -170,10 +172,10 @@ func (a *DiscordAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, chatType = "guild" } - isMentioned := a.isBotMentioned(m.Message, botId) + isMentioned := a.isBotMentioned(m.Message, botID) isReplyToBot := m.ReferencedMessage != nil && m.ReferencedMessage.Author != nil && - m.ReferencedMessage.Author.ID == botId + m.ReferencedMessage.Author.ID == botID msg := channel.InboundMessage{ Channel: Type, @@ -229,7 +231,7 @@ func (a *DiscordAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, return nil, fmt.Errorf("discord open connection: %w", err) } - stop := func(stopCtx context.Context) error { + stop := func(_ context.Context) error { if a.logger != nil { a.logger.Info("stop", slog.String("config_id", cfg.ID)) } @@ -256,7 +258,7 @@ func (a *DiscordAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, ms channelID := strings.TrimSpace(msg.Target) if channelID == "" { - return fmt.Errorf("discord target is required") + return errors.New("discord target is required") } // Get botID from config metadata if available @@ -268,7 +270,7 @@ func (a *DiscordAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, ms return a.sendDiscordMessage(ctx, session, channelID, botID, msg) } -func (a *DiscordAdapter) sendDiscordMessage(ctx context.Context, session *discordgo.Session, channelID, botID string, msg channel.OutboundMessage) error { +func (a *DiscordAdapter) sendDiscordMessage(ctx context.Context, session *discordgo.Session, channelID, _ string, msg channel.OutboundMessage) error { content := truncateDiscordText(msg.Message.Text) // Build message send parameters @@ -302,7 +304,7 @@ func (a *DiscordAdapter) sendDiscordMessage(ctx context.Context, session *discor // Validate: must have content or files if messageSend.Content == "" && len(messageSend.Files) == 0 { - return fmt.Errorf("cannot send empty message: no content and no valid attachments") + return errors.New("cannot send empty message: no content and no valid attachments") } _, err := session.ChannelMessageSendComplex(channelID, messageSend) @@ -317,7 +319,7 @@ func truncateDiscordText(text string) string { return string(runes[:discordMaxLength-3]) + "..." } -// discordAttachmentToFile converts a channel attachment to discordgo.File +// discordAttachmentToFile converts a channel attachment to discordgo.File. func discordAttachmentToFile(ctx context.Context, att channel.Attachment, opener assetOpener) *discordgo.File { // Get file name name := att.Name @@ -343,7 +345,7 @@ func discordAttachmentToFile(ctx context.Context, att channel.Attachment, opener if att.ContentHash != "" && botID != "" && opener != nil { if rc, _, err := opener.Open(ctx, botID, att.ContentHash); err == nil { data, _ := io.ReadAll(rc) - rc.Close() + _ = rc.Close() if len(data) > 0 { reader = bytes.NewReader(data) } @@ -360,11 +362,14 @@ func discordAttachmentToFile(ctx context.Context, att channel.Attachment, opener // Fallback to URL if reader == nil && att.URL != "" { - resp, err := http.Get(att.URL) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, att.URL, nil) if err == nil { - defer resp.Body.Close() - data, _ := io.ReadAll(resp.Body) - reader = bytes.NewReader(data) + resp, doErr := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is a Discord attachment URL received from the Discord API + if doErr == nil { + defer func() { _ = resp.Body.Close() }() + data, _ := io.ReadAll(resp.Body) + reader = bytes.NewReader(data) + } } } @@ -378,16 +383,16 @@ func discordAttachmentToFile(ctx context.Context, att channel.Attachment, opener } } -// base64DataURLToBytes decodes a base64 data URL to bytes +// base64DataURLToBytes decodes a base64 data URL to bytes. func base64DataURLToBytes(dataURL string) ([]byte, error) { parts := strings.SplitN(dataURL, ",", 2) if len(parts) != 2 { - return nil, fmt.Errorf("invalid data URL") + return nil, errors.New("invalid data URL") } return base64.StdEncoding.DecodeString(parts[1]) } -// mimeExtension returns file extension for common mime types +// mimeExtension returns file extension for common mime types. func mimeExtension(mime string) string { switch mime { case "image/jpeg", "image/jpg": @@ -417,10 +422,10 @@ func mimeExtension(mime string) string { } } -func (a *DiscordAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (a *DiscordAdapter) OpenStream(_ context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("discord target is required") + return nil, errors.New("discord target is required") } discordCfg, err := parseConfig(cfg.Credentials) @@ -442,7 +447,7 @@ func (a *DiscordAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConf }, nil } -func (a *DiscordAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (a *DiscordAdapter) ProcessingStarted(_ context.Context, cfg channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { chatID := strings.TrimSpace(info.ReplyTarget) if chatID == "" { return channel.ProcessingStatusHandle{}, nil @@ -463,15 +468,15 @@ func (a *DiscordAdapter) ProcessingStarted(ctx context.Context, cfg channel.Chan return channel.ProcessingStatusHandle{}, err } -func (a *DiscordAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (*DiscordAdapter) ProcessingCompleted(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle) error { return nil } -func (a *DiscordAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (*DiscordAdapter) ProcessingFailed(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle, _ error) error { return nil } -func (a *DiscordAdapter) React(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { +func (a *DiscordAdapter) React(_ context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { discordCfg, err := parseConfig(cfg.Credentials) if err != nil { return err @@ -485,7 +490,7 @@ func (a *DiscordAdapter) React(ctx context.Context, cfg channel.ChannelConfig, t return session.MessageReactionAdd(target, messageID, emoji) } -func (a *DiscordAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { +func (a *DiscordAdapter) Unreact(_ context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { discordCfg, err := parseConfig(cfg.Credentials) if err != nil { return err @@ -499,31 +504,31 @@ func (a *DiscordAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, return session.MessageReactionRemove(target, messageID, emoji, "@me") } -func (a *DiscordAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (*DiscordAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) } -func (a *DiscordAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func (*DiscordAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return normalizeUserConfig(raw) } -func (a *DiscordAdapter) NormalizeTarget(raw string) string { +func (*DiscordAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } -func (a *DiscordAdapter) ResolveTarget(userConfig map[string]any) (string, error) { +func (*DiscordAdapter) ResolveTarget(userConfig map[string]any) (string, error) { return resolveTarget(userConfig) } -func (a *DiscordAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { +func (*DiscordAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { return matchBinding(config, criteria) } -func (a *DiscordAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (*DiscordAdapter) BuildUserConfig(identity channel.Identity) map[string]any { return buildUserConfig(identity) } -func (a *DiscordAdapter) collectAttachments(msg *discordgo.Message) []channel.Attachment { +func (*DiscordAdapter) collectAttachments(msg *discordgo.Message) []channel.Attachment { if msg == nil || len(msg.Attachments) == 0 { return nil } @@ -558,7 +563,7 @@ func (a *DiscordAdapter) collectAttachments(msg *discordgo.Message) []channel.At return attachments } -func (a *DiscordAdapter) isBotMentioned(msg *discordgo.Message, botID string) bool { +func (*DiscordAdapter) isBotMentioned(msg *discordgo.Message, botID string) bool { if msg == nil { return false } diff --git a/internal/channel/adapters/discord/discord_test.go b/internal/channel/adapters/discord/discord_test.go index f62c8343..5009cbe7 100644 --- a/internal/channel/adapters/discord/discord_test.go +++ b/internal/channel/adapters/discord/discord_test.go @@ -4,7 +4,6 @@ import ( "testing" ) - func TestMimeExtension(t *testing.T) { tests := []struct { mime string diff --git a/internal/channel/adapters/discord/stream.go b/internal/channel/adapters/discord/stream.go index 59ce66ea..686675f2 100644 --- a/internal/channel/adapters/discord/stream.go +++ b/internal/channel/adapters/discord/stream.go @@ -2,6 +2,7 @@ package discord import ( "context" + "errors" "fmt" "strings" "sync" @@ -9,129 +10,130 @@ import ( "time" "github.com/bwmarrin/discordgo" + "github.com/memohai/memoh/internal/channel" ) type discordOutboundStream struct { - adapter *DiscordAdapter - cfg channel.ChannelConfig - target string - reply *channel.ReplyRef - session *discordgo.Session - closed atomic.Bool - mu sync.Mutex - msgID string - buffer strings.Builder - lastUpdate time.Time + adapter *DiscordAdapter + cfg channel.ChannelConfig + target string + reply *channel.ReplyRef + session *discordgo.Session + closed atomic.Bool + mu sync.Mutex + msgID string + buffer strings.Builder + lastUpdate time.Time } func (s *discordOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { - if s == nil || s.adapter == nil { - return fmt.Errorf("discord stream not configured") - } - if s.closed.Load() { - return fmt.Errorf("discord stream is closed") - } + if s == nil || s.adapter == nil { + return errors.New("discord stream not configured") + } + if s.closed.Load() { + return errors.New("discord stream is closed") + } - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } - switch event.Type { - case channel.StreamEventStatus: - if event.Status == channel.StreamStatusStarted { + switch event.Type { + case channel.StreamEventStatus: + if event.Status == channel.StreamStatusStarted { return s.ensureMessage("Thinking...") - } - return nil + } + return nil - case channel.StreamEventDelta: - if event.Delta == "" || event.Phase == channel.StreamPhaseReasoning { - return nil - } - s.mu.Lock() - s.buffer.WriteString(event.Delta) - s.mu.Unlock() + case channel.StreamEventDelta: + if event.Delta == "" || event.Phase == channel.StreamPhaseReasoning { + return nil + } + s.mu.Lock() + s.buffer.WriteString(event.Delta) + s.mu.Unlock() - // Discord has strict rate limits, only update periodically - if time.Since(s.lastUpdate) > 2*time.Second { + // Discord has strict rate limits, only update periodically + if time.Since(s.lastUpdate) > 2*time.Second { return s.updateMessage() - } - return nil + } + return nil - case channel.StreamEventFinal: - if event.Final != nil && !event.Final.Message.IsEmpty() { - finalText := strings.TrimSpace(event.Final.Message.PlainText()) - if finalText != "" { + case channel.StreamEventFinal: + if event.Final != nil && !event.Final.Message.IsEmpty() { + finalText := strings.TrimSpace(event.Final.Message.PlainText()) + if finalText != "" { return s.finalizeMessage(finalText) - } - } - s.mu.Lock() - finalText := strings.TrimSpace(s.buffer.String()) - s.mu.Unlock() - if finalText != "" { + } + } + s.mu.Lock() + finalText := strings.TrimSpace(s.buffer.String()) + s.mu.Unlock() + if finalText != "" { return s.finalizeMessage(finalText) - } - return nil + } + return nil - case channel.StreamEventError: - errText := strings.TrimSpace(event.Error) - if errText == "" { - return nil - } + case channel.StreamEventError: + errText := strings.TrimSpace(event.Error) + if errText == "" { + return nil + } return s.finalizeMessage("Error: " + errText) - case channel.StreamEventAttachment: - if len(event.Attachments) == 0 { - return nil - } - // Finalize current text message before sending attachments - s.mu.Lock() - finalText := strings.TrimSpace(s.buffer.String()) - s.mu.Unlock() - if finalText != "" { - if err := s.finalizeMessage(finalText); err != nil { - return err - } - } - // Send attachments - for _, att := range event.Attachments { - if err := s.sendAttachment(att); err != nil { - return err - } - } - return nil + case channel.StreamEventAttachment: + if len(event.Attachments) == 0 { + return nil + } + // Finalize current text message before sending attachments + s.mu.Lock() + finalText := strings.TrimSpace(s.buffer.String()) + s.mu.Unlock() + if finalText != "" { + if err := s.finalizeMessage(finalText); err != nil { + return err + } + } + // Send attachments + for _, att := range event.Attachments { + if err := s.sendAttachment(ctx, att); err != nil { + return err + } + } + return nil - case channel.StreamEventAgentStart, channel.StreamEventAgentEnd, channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd, channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted, channel.StreamEventProcessingFailed, channel.StreamEventToolCallStart, channel.StreamEventToolCallEnd: - // Status events - no action needed for Discord - return nil + case channel.StreamEventAgentStart, channel.StreamEventAgentEnd, channel.StreamEventPhaseStart, channel.StreamEventPhaseEnd, channel.StreamEventProcessingStarted, channel.StreamEventProcessingCompleted, channel.StreamEventProcessingFailed, channel.StreamEventToolCallStart, channel.StreamEventToolCallEnd: + // Status events - no action needed for Discord + return nil - default: - return fmt.Errorf("unsupported stream event type: %s", event.Type) - } + default: + return fmt.Errorf("unsupported stream event type: %s", event.Type) + } } func (s *discordOutboundStream) Close(ctx context.Context) error { - if s == nil { - return nil - } - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - s.closed.Store(true) - return nil + if s == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.closed.Store(true) + return nil } func (s *discordOutboundStream) ensureMessage(text string) error { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() - if s.msgID != "" { - return nil - } + if s.msgID != "" { + return nil + } content := truncateDiscordText(text) @@ -145,46 +147,46 @@ func (s *discordOutboundStream) ensureMessage(text string) error { } else { msg, err = s.session.ChannelMessageSend(s.target, content) } - if err != nil { - return err - } + if err != nil { + return err + } - s.msgID = msg.ID - s.lastUpdate = time.Now() - return nil + s.msgID = msg.ID + s.lastUpdate = time.Now() + return nil } func (s *discordOutboundStream) updateMessage() error { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() - if s.msgID == "" { - return nil - } + if s.msgID == "" { + return nil + } - content := s.buffer.String() - if content == "" { - return nil - } + content := s.buffer.String() + if content == "" { + return nil + } content = truncateDiscordText(content) - _, err := s.session.ChannelMessageEdit(s.target, s.msgID, content) - if err != nil { - return err - } + _, err := s.session.ChannelMessageEdit(s.target, s.msgID, content) + if err != nil { + return err + } - s.lastUpdate = time.Now() - return nil + s.lastUpdate = time.Now() + return nil } func (s *discordOutboundStream) finalizeMessage(text string) error { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() text = truncateDiscordText(text) - if s.msgID == "" { + if s.msgID == "" { var msg *discordgo.Message var err error if s.reply != nil && s.reply.MessageID != "" { @@ -201,14 +203,13 @@ func (s *discordOutboundStream) finalizeMessage(text string) error { s.msgID = msg.ID s.lastUpdate = time.Now() return nil - } + } - _, err := s.session.ChannelMessageEdit(s.target, s.msgID, text) - return err + _, err := s.session.ChannelMessageEdit(s.target, s.msgID, text) + return err } -func (s *discordOutboundStream) sendAttachment(att channel.Attachment) error { - ctx := context.Background() +func (s *discordOutboundStream) sendAttachment(ctx context.Context, att channel.Attachment) error { file := discordAttachmentToFile(ctx, att, s.adapter.assets) if file == nil { return nil @@ -228,4 +229,4 @@ func (s *discordOutboundStream) sendAttachment(att channel.Attachment) error { _, err := s.session.ChannelMessageSendComplex(s.target, messageSend) return err -} \ No newline at end of file +} diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go index 3ab38e8b..7a0e87df 100644 --- a/internal/channel/adapters/feishu/config.go +++ b/internal/channel/adapters/feishu/config.go @@ -1,7 +1,7 @@ package feishu import ( - "fmt" + "errors" "strings" lark "github.com/larksuite/oapi-sdk-go/v3" @@ -79,7 +79,7 @@ func resolveTarget(raw map[string]any) (string, error) { if cfg.UserID != "" { return "user_id:" + cfg.UserID, nil } - return "", fmt.Errorf("feishu binding is incomplete") + return "", errors.New("feishu binding is incomplete") } func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { @@ -126,7 +126,7 @@ func parseConfig(raw map[string]any) (Config, error) { return Config{}, err } if appID == "" || appSecret == "" { - return Config{}, fmt.Errorf("feishu appId and appSecret are required") + return Config{}, errors.New("feishu appId and appSecret are required") } return Config{ AppID: appID, @@ -142,7 +142,7 @@ func parseUserConfig(raw map[string]any) (UserConfig, error) { openID := strings.TrimSpace(channel.ReadString(raw, "openId", "open_id")) userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) if openID == "" && userID == "" { - return UserConfig{}, fmt.Errorf("feishu user config requires open_id or user_id") + return UserConfig{}, errors.New("feishu user config requires open_id or user_id") } return UserConfig{OpenID: openID, UserID: userID}, nil } @@ -171,7 +171,7 @@ func normalizeRegion(raw string) (string, error) { case regionLark, "global", "intl", "international": return regionLark, nil default: - return "", fmt.Errorf("feishu region must be feishu or lark") + return "", errors.New("feishu region must be feishu or lark") } } @@ -182,7 +182,7 @@ func normalizeInboundMode(raw string) (string, error) { case inboundModeWebhook: return inboundModeWebhook, nil default: - return "", fmt.Errorf("feishu inbound_mode must be websocket or webhook") + return "", errors.New("feishu inbound_mode must be websocket or webhook") } } diff --git a/internal/channel/adapters/feishu/connect_mode_test.go b/internal/channel/adapters/feishu/connect_mode_test.go index d47f1c7a..8b40ec96 100644 --- a/internal/channel/adapters/feishu/connect_mode_test.go +++ b/internal/channel/adapters/feishu/connect_mode_test.go @@ -21,7 +21,7 @@ func TestConnectWebhookModeDoesNotStartWebsocket(t *testing.T) { "inbound_mode": "webhook", }, } - conn, err := adapter.Connect(context.Background(), cfg, func(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) error { + conn, err := adapter.Connect(context.Background(), cfg, func(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage) error { return nil }) if err != nil { diff --git a/internal/channel/adapters/feishu/directory.go b/internal/channel/adapters/feishu/directory.go index 4b6f0cde..bfcbe084 100644 --- a/internal/channel/adapters/feishu/directory.go +++ b/internal/channel/adapters/feishu/directory.go @@ -2,6 +2,7 @@ package feishu import ( "context" + "errors" "fmt" "strings" @@ -28,7 +29,7 @@ func directoryLimit(n int) int { } // ListPeers lists users (peers) from Feishu contact, optionally filtered by query. -func (a *FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err @@ -60,7 +61,7 @@ func (a *FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig } // ListGroups lists chat groups from Feishu IM, optionally filtered by query. -func (a *FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err @@ -104,17 +105,15 @@ func (a *FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfi } // ListGroupMembers lists members of a Feishu chat group. -func (a *FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err } chatID := strings.TrimSpace(groupID) - if strings.HasPrefix(chatID, "chat_id:") { - chatID = strings.TrimPrefix(chatID, "chat_id:") - } + chatID = strings.TrimPrefix(chatID, "chat_id:") if chatID == "" { - return nil, fmt.Errorf("feishu list group members: empty group id") + return nil, errors.New("feishu list group members: empty group id") } client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) pageSize := directoryLimit(query.Limit) @@ -159,7 +158,7 @@ func (a *FeishuAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelCon } } -func (a *FeishuAdapter) resolveUser(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { +func (*FeishuAdapter) resolveUser(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { userID, userIDType := parseFeishuUserInput(input) if userID == "" { return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry user: invalid input %q", input) @@ -176,16 +175,14 @@ func (a *FeishuAdapter) resolveUser(ctx context.Context, client *lark.Client, in return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: code=%d msg=%s", resp.Code, resp.Msg) } if resp.Data == nil || resp.Data.User == nil { - return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: empty response") + return channel.DirectoryEntry{}, errors.New("feishu get user: empty response") } return feishuUserToEntry(resp.Data.User), nil } -func (a *FeishuAdapter) resolveGroup(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { +func (*FeishuAdapter) resolveGroup(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { chatID := strings.TrimSpace(input) - if strings.HasPrefix(chatID, "chat_id:") { - chatID = strings.TrimPrefix(chatID, "chat_id:") - } + chatID = strings.TrimPrefix(chatID, "chat_id:") if chatID == "" { return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry group: invalid input %q", input) } diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index 4277522b..b5e55399 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -52,7 +53,7 @@ type larkProcessingReactionGateway struct { func (g *larkProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { if g == nil || g.api == nil { - return "", fmt.Errorf("feishu reaction api not configured") + return "", errors.New("feishu reaction api not configured") } req := larkim.NewCreateMessageReactionReqBuilder(). MessageId(messageID). @@ -74,14 +75,14 @@ func (g *larkProcessingReactionGateway) Add(ctx context.Context, messageID, reac return "", fmt.Errorf("feishu add reaction failed: %s (code: %d)", msg, code) } if resp.Data == nil || resp.Data.ReactionId == nil || strings.TrimSpace(*resp.Data.ReactionId) == "" { - return "", fmt.Errorf("feishu add reaction failed: empty reaction id") + return "", errors.New("feishu add reaction failed: empty reaction id") } return strings.TrimSpace(*resp.Data.ReactionId), nil } func (g *larkProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { if g == nil || g.api == nil { - return fmt.Errorf("feishu reaction api not configured") + return errors.New("feishu reaction api not configured") } req := larkim.NewDeleteMessageReactionReqBuilder(). MessageId(messageID). @@ -119,12 +120,12 @@ func (a *FeishuAdapter) SetAssetOpener(opener assetOpener) { } // Type returns the Feishu channel type. -func (a *FeishuAdapter) Type() channel.ChannelType { +func (*FeishuAdapter) Type() channel.ChannelType { return Type } // Descriptor returns the Feishu channel metadata. -func (a *FeishuAdapter) Descriptor() channel.Descriptor { +func (*FeishuAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: Type, DisplayName: "Feishu", @@ -186,7 +187,7 @@ func (a *FeishuAdapter) Descriptor() channel.Descriptor { } // ProcessingStarted adds a transient reaction to indicate the inbound message is being processed. -func (a *FeishuAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (a *FeishuAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { messageID := strings.TrimSpace(info.SourceMessageID) if messageID == "" { return channel.ProcessingStatusHandle{}, nil @@ -203,7 +204,7 @@ func (a *FeishuAdapter) ProcessingStarted(ctx context.Context, cfg channel.Chann } // ProcessingCompleted removes the transient processing reaction before output is sent. -func (a *FeishuAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (a *FeishuAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { messageID := strings.TrimSpace(info.SourceMessageID) reactionID := strings.TrimSpace(handle.Token) if messageID == "" || reactionID == "" { @@ -217,17 +218,17 @@ func (a *FeishuAdapter) ProcessingCompleted(ctx context.Context, cfg channel.Cha } // ProcessingFailed removes the transient processing reaction when chat processing fails. -func (a *FeishuAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (a *FeishuAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, _ error) error { return a.ProcessingCompleted(ctx, cfg, msg, info, handle) } -func (a *FeishuAdapter) processingReactionGateway(cfg channel.ChannelConfig) (processingReactionGateway, error) { +func (*FeishuAdapter) processingReactionGateway(cfg channel.ChannelConfig) (processingReactionGateway, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err } client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret, lark.WithOpenBaseUrl(feishuCfg.openBaseURL())) - gateway := &larkProcessingReactionGateway{api: client.Im.V1.MessageReaction} + gateway := &larkProcessingReactionGateway{api: client.Im.MessageReaction} return gateway, nil } @@ -259,7 +260,7 @@ func (a *FeishuAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, func addProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionType string) (string, error) { if gateway == nil { - return "", fmt.Errorf("processing reaction gateway is nil") + return "", errors.New("processing reaction gateway is nil") } msgID := strings.TrimSpace(messageID) if msgID == "" { @@ -267,14 +268,14 @@ func addProcessingReaction(ctx context.Context, gateway processingReactionGatewa } rxType := strings.TrimSpace(reactionType) if rxType == "" { - return "", fmt.Errorf("processing reaction type is empty") + return "", errors.New("processing reaction type is empty") } return gateway.Add(ctx, msgID, rxType) } func removeProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionID string) error { if gateway == nil { - return fmt.Errorf("processing reaction gateway is nil") + return errors.New("processing reaction gateway is nil") } msgID := strings.TrimSpace(messageID) rxID := strings.TrimSpace(reactionID) @@ -285,7 +286,7 @@ func removeProcessingReaction(ctx context.Context, gateway processingReactionGat } // DiscoverSelf retrieves the bot's own identity from the Feishu platform. -func (a *FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { +func (*FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { cfg, err := parseConfig(credentials) if err != nil { return nil, "", err @@ -312,7 +313,7 @@ func (a *FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string } openID := strings.TrimSpace(body.Bot.OpenID) if openID == "" { - return nil, "", fmt.Errorf("feishu discover self: empty open_id") + return nil, "", errors.New("feishu discover self: empty open_id") } identity := map[string]any{ "open_id": openID, @@ -327,32 +328,32 @@ func (a *FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string } // NormalizeConfig validates and normalizes a Feishu channel configuration map. -func (a *FeishuAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (*FeishuAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) } // NormalizeUserConfig validates and normalizes a Feishu user-binding configuration map. -func (a *FeishuAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func (*FeishuAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return normalizeUserConfig(raw) } // NormalizeTarget normalizes a Feishu delivery target string. -func (a *FeishuAdapter) NormalizeTarget(raw string) string { +func (*FeishuAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } // ResolveTarget derives a delivery target from a Feishu user-binding configuration. -func (a *FeishuAdapter) ResolveTarget(userConfig map[string]any) (string, error) { +func (*FeishuAdapter) ResolveTarget(userConfig map[string]any) (string, error) { return resolveTarget(userConfig) } // MatchBinding reports whether a Feishu user binding matches the given criteria. -func (a *FeishuAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { +func (*FeishuAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { return matchBinding(config, criteria) } // BuildUserConfig constructs a Feishu user-binding config from an Identity. -func (a *FeishuAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (*FeishuAdapter) BuildUserConfig(identity channel.Identity) map[string]any { return buildUserConfig(identity) } @@ -388,7 +389,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, if connCtx.Err() != nil { return nil } - msg := extractFeishuInbound(event, botOpenID) + msg := extractFeishuInbound(event, botOpenID, a.logger) text := msg.Message.PlainText() rawMessageID := "" rawMessageType := "" @@ -398,7 +399,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, rawMessageID = strings.TrimSpace(*event.Event.Message.MessageId) } if event.Event.Message.MessageType != nil { - rawMessageType = strings.TrimSpace(string(*event.Event.Message.MessageType)) + rawMessageType = strings.TrimSpace(*event.Event.Message.MessageType) } if event.Event.Message.Content != nil { rawContent = common.SummarizeText(*event.Event.Message.Content) @@ -549,7 +550,7 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg msgType = larkim.MsgTypeText text := strings.TrimSpace(msg.Message.PlainText()) if text == "" { - return fmt.Errorf("message is required") + return errors.New("message is required") } payload, marshalErr := json.Marshal(map[string]string{"text": text}) if marshalErr != nil { @@ -578,11 +579,11 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg Uuid(uuid.NewString()). Build()). Build() - resp, err := client.Im.V1.Message.Reply(ctx, replyReq) + resp, err := client.Im.Message.Reply(ctx, replyReq) return a.handleReplyResponse(cfg.ID, resp, err) } - resp, err := client.Im.V1.Message.Create(ctx, req) + resp, err := client.Im.Message.Create(ctx, req) return a.handleResponse(cfg.ID, resp, err) } @@ -591,7 +592,7 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg func (a *FeishuAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("feishu target is required") + return nil, errors.New("feishu target is required") } feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { @@ -634,7 +635,7 @@ func (a *FeishuAdapter) handleReplyResponse(configID string, resp *larkim.ReplyM msg = resp.Msg } if a.logger != nil { - a.logger.Error("reply failed", slog.String("config_id", configID), slog.Int("code", code), slog.String("msg", msg)) + a.logger.Error("reply failed", slog.String("config_id", configID), slog.Int("code", code), slog.String("error_message", msg)) } return fmt.Errorf("feishu reply failed: %s (code: %d)", msg, code) } @@ -659,7 +660,7 @@ func (a *FeishuAdapter) handleResponse(configID string, resp *larkim.CreateMessa msg = resp.Msg } if a.logger != nil { - a.logger.Error("send failed", slog.String("config_id", configID), slog.Int("code", code), slog.String("msg", msg)) + a.logger.Error("send failed", slog.String("config_id", configID), slog.Int("code", code), slog.String("error_message", msg)) } return fmt.Errorf("feishu send failed: %s (code: %d)", msg, code) } @@ -701,7 +702,7 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, Image(reader). Build()). Build() - uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) + uploadResp, err := client.Im.Image.Create(ctx, uploadReq) if err != nil { return fmt.Errorf("failed to upload image: %w", err) } @@ -727,7 +728,7 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, File(reader). Build()). Build() - uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) + uploadResp, err := client.Im.File.Create(ctx, uploadReq) if err != nil { return fmt.Errorf("failed to upload file: %w", err) } @@ -757,7 +758,7 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, Build()). Build() - sendResp, err := client.Im.V1.Message.Create(ctx, req) + sendResp, err := client.Im.Message.Create(ctx, req) return a.handleResponse("", sendResp, err) } @@ -809,14 +810,14 @@ func (a *FeishuAdapter) resolveAttachmentUploadReader(ctx context.Context, att c } if downloadURL == "" { - return nil, "", "", fmt.Errorf("attachment reference is required: provide platform_key/content_hash/base64/url") + return nil, "", "", errors.New("attachment reference is required: provide platform_key/content_hash/base64/url") } httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) if err != nil { return nil, "", "", fmt.Errorf("failed to build download request: %w", err) } httpClient := &http.Client{Timeout: 60 * time.Second} - resp, err := httpClient.Do(httpReq) + resp, err := httpClient.Do(httpReq) //nolint:gosec // G704: URL is a Feishu file download URL from the Feishu API if err != nil { return nil, "", "", fmt.Errorf("failed to download attachment: %w", err) } @@ -835,10 +836,10 @@ func (a *FeishuAdapter) resolveAttachmentUploadReader(ctx context.Context, att c // User-sent resources must be fetched via the message-resource API which // requires both message_id and file_key. The message_id is expected in // attachment.Metadata["message_id"]. -func (a *FeishuAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) { +func (*FeishuAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) { platformKey := strings.TrimSpace(attachment.PlatformKey) if platformKey == "" { - return channel.AttachmentPayload{}, fmt.Errorf("feishu attachment platform_key is required") + return channel.AttachmentPayload{}, errors.New("feishu attachment platform_key is required") } messageID := "" if attachment.Metadata != nil { @@ -847,7 +848,7 @@ func (a *FeishuAdapter) ResolveAttachment(ctx context.Context, cfg channel.Chann } } if messageID == "" { - return channel.AttachmentPayload{}, fmt.Errorf("feishu attachment metadata.message_id is required") + return channel.AttachmentPayload{}, errors.New("feishu attachment metadata.message_id is required") } feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { @@ -864,7 +865,7 @@ func (a *FeishuAdapter) ResolveAttachment(ctx context.Context, cfg channel.Chann FileKey(platformKey). Type(resourceType). Build() - resp, err := client.Im.V1.MessageResource.Get(ctx, req) + resp, err := client.Im.MessageResource.Get(ctx, req) if err != nil { return channel.AttachmentPayload{}, fmt.Errorf("download feishu resource: %w", err) } @@ -872,7 +873,7 @@ func (a *FeishuAdapter) ResolveAttachment(ctx context.Context, cfg channel.Chann return channel.AttachmentPayload{}, fmt.Errorf("download feishu resource: %s (code: %d)", resp.Msg, resp.Code) } if resp.File == nil { - return channel.AttachmentPayload{}, fmt.Errorf("download feishu resource: empty payload") + return channel.AttachmentPayload{}, errors.New("download feishu resource: empty payload") } mime := strings.TrimSpace(attachment.Mime) if mime == "" { @@ -925,7 +926,7 @@ func resolveFeishuFileType(name, mime string) string { } } -func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { +func (*FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { type postContent struct { ZhCn struct { Title string `json:"title"` diff --git a/internal/channel/adapters/feishu/feishu_integration_test.go b/internal/channel/adapters/feishu/feishu_integration_test.go index d749556b..9bcaf6b1 100644 --- a/internal/channel/adapters/feishu/feishu_integration_test.go +++ b/internal/channel/adapters/feishu/feishu_integration_test.go @@ -76,7 +76,9 @@ func TestFeishuGateway_Integration(t *testing.T) { Text: "【Memoh 集成测试】主动推送验证成功。", }, } - _ = adapter.Send(context.Background(), c, pushMsg) + pushCtx, pushCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer pushCancel() + _ = adapter.Send(pushCtx, c, pushMsg) }() return nil diff --git a/internal/channel/adapters/feishu/feishu_test.go b/internal/channel/adapters/feishu/feishu_test.go index 6f6e5ad0..e8e8f4ae 100644 --- a/internal/channel/adapters/feishu/feishu_test.go +++ b/internal/channel/adapters/feishu/feishu_test.go @@ -22,7 +22,7 @@ type fakeProcessingReactionGateway struct { removeErr error } -func (g *fakeProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { +func (g *fakeProcessingReactionGateway) Add(_ context.Context, messageID, reactionType string) (string, error) { g.addCalls = append(g.addCalls, struct{ messageID, reactionType string }{ messageID: messageID, reactionType: reactionType, @@ -35,7 +35,7 @@ func (g *fakeProcessingReactionGateway) Add(ctx context.Context, messageID, reac return resp.reactionID, resp.err } -func (g *fakeProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { +func (g *fakeProcessingReactionGateway) Remove(_ context.Context, messageID, reactionID string) error { g.removeCalls = append(g.removeCalls, struct{ messageID, reactionID string }{ messageID: messageID, reactionID: reactionID, diff --git a/internal/channel/adapters/feishu/inbound.go b/internal/channel/adapters/feishu/inbound.go index ee1ed02b..bd2b3e80 100644 --- a/internal/channel/adapters/feishu/inbound.go +++ b/internal/channel/adapters/feishu/inbound.go @@ -2,6 +2,7 @@ package feishu import ( "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -14,7 +15,11 @@ import ( // extractFeishuInbound converts a Feishu P2MessageReceiveV1 event into a channel.InboundMessage. // botOpenID is the bot's own open_id used to filter mentions; if empty, any mention is treated as bot mention. -func extractFeishuInbound(event *larkim.P2MessageReceiveV1, botOpenID string) channel.InboundMessage { +func extractFeishuInbound(event *larkim.P2MessageReceiveV1, botOpenID string, loggers ...*slog.Logger) channel.InboundMessage { + var log *slog.Logger + if len(loggers) > 0 { + log = loggers[0] + } if event == nil || event.Event == nil || event.Event.Message == nil { return channel.InboundMessage{Channel: Type} } @@ -28,7 +33,9 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1, botOpenID string) ch var contentMap map[string]any if message.Content != nil { if err := json.Unmarshal([]byte(*message.Content), &contentMap); err != nil { - slog.Warn("feishu inbound: unmarshal content failed", slog.Any("error", err)) + if log != nil { + log.Warn("feishu inbound: unmarshal content failed", slog.Any("error", err)) + } } } isMentioned := isFeishuBotMentioned(contentMap, message.Mentions, botOpenID) @@ -45,15 +52,15 @@ func extractFeishuInbound(event *larkim.P2MessageReceiveV1, botOpenID string) ch msg.Text = postText } postAtts := extractFeishuPostAttachments(contentMap, msg.ID) - for _, att := range postAtts { - msg.Attachments = append(msg.Attachments, att) - } + msg.Attachments = append(msg.Attachments, postAtts...) if len(postAtts) > 0 || postText != "" { - slog.Debug("feishu post extracted", - "message_id", msg.ID, - "text_len", len(postText), - "attachments", len(postAtts), - ) + if log != nil { + log.Debug("feishu post extracted", + slog.String("message_id", msg.ID), + slog.Int("text_len", len(postText)), + slog.Int("attachments", len(postAtts)), + ) + } } case larkim.MsgTypeImage: if key, ok := contentMap["image_key"].(string); ok { @@ -165,10 +172,7 @@ func isFeishuBotMentioned(contentMap map[string]any, mentions []*larkim.MentionE return true } } - if matchFeishuContentMention(contentMap, botOpenID) { - return true - } - return false + return matchFeishuContentMention(contentMap, botOpenID) } // hasAnyFeishuMention is the fallback when the bot's open_id is unknown. @@ -373,7 +377,7 @@ func stringValue(raw any) string { // resolveFeishuReceiveID parses target (open_id:/user_id:/chat_id: prefix) and returns receiveID and receiveType. func resolveFeishuReceiveID(raw string) (string, string, error) { if raw == "" { - return "", "", fmt.Errorf("feishu target is required") + return "", "", errors.New("feishu target is required") } if strings.HasPrefix(raw, "open_id:") { return strings.TrimPrefix(raw, "open_id:"), larkim.ReceiveIdTypeOpenId, nil diff --git a/internal/channel/adapters/feishu/sender_profile.go b/internal/channel/adapters/feishu/sender_profile.go index 5fd14de6..fb81702a 100644 --- a/internal/channel/adapters/feishu/sender_profile.go +++ b/internal/channel/adapters/feishu/sender_profile.go @@ -2,7 +2,9 @@ package feishu import ( "context" + "errors" "fmt" + "log/slog" "strings" "time" @@ -46,9 +48,6 @@ func (a *FeishuAdapter) enrichSenderProfile(ctx context.Context, cfg channel.Cha chatID = strings.TrimSpace(*event.Event.Message.ChatId) } - if ctx == nil { - ctx = context.Background() - } lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() @@ -56,11 +55,11 @@ func (a *FeishuAdapter) enrichSenderProfile(ctx context.Context, cfg channel.Cha if err != nil { if a.logger != nil { a.logger.Debug("feishu sender profile lookup failed", - "config_id", cfg.ID, - "open_id", openID, - "user_id", userID, - "chat_id", chatID, - "error", err, + slog.String("config_id", cfg.ID), + slog.String("open_id", openID), + slog.String("user_id", userID), + slog.String("chat_id", chatID), + slog.Any("error", err), ) } } @@ -70,7 +69,7 @@ func (a *FeishuAdapter) enrichSenderProfile(ctx context.Context, cfg channel.Cha applySenderProfile(msg, profile) } -func (a *FeishuAdapter) lookupSenderProfile(ctx context.Context, cfg channel.ChannelConfig, openID, userID, chatID string) (feishuSenderProfile, error) { +func (*FeishuAdapter) lookupSenderProfile(ctx context.Context, cfg channel.ChannelConfig, openID, userID, chatID string) (feishuSenderProfile, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return feishuSenderProfile{}, err @@ -79,9 +78,7 @@ func (a *FeishuAdapter) lookupSenderProfile(ctx context.Context, cfg channel.Cha var lastErr error chatID = strings.TrimSpace(chatID) - if strings.HasPrefix(chatID, "chat_id:") { - chatID = strings.TrimPrefix(chatID, "chat_id:") - } + chatID = strings.TrimPrefix(chatID, "chat_id:") // Group scene: chat members has the highest chance to return a human-readable name. if chatID != "" && openID != "" { @@ -125,7 +122,7 @@ func lookupSenderProfileFromContact(ctx context.Context, client *lark.Client, op idType = larkcontact.UserIdTypeUserId } if lookupID == "" { - return feishuSenderProfile{}, fmt.Errorf("empty sender id") + return feishuSenderProfile{}, errors.New("empty sender id") } req := larkcontact.NewGetUserReqBuilder(). UserIdType(idType). @@ -145,7 +142,7 @@ func lookupSenderProfileFromContact(ctx context.Context, client *lark.Client, op return feishuSenderProfile{}, fmt.Errorf("feishu get user failed: code=%d msg=%s", code, msg) } if resp.Data == nil || resp.Data.User == nil { - return feishuSenderProfile{}, fmt.Errorf("feishu get user returned empty user") + return feishuSenderProfile{}, errors.New("feishu get user returned empty user") } displayName := ptrStr(resp.Data.User.Name) username := ptrStr(resp.Data.User.Nickname) @@ -162,7 +159,7 @@ func lookupSenderProfileFromGroupMember(ctx context.Context, client *lark.Client memberIDType = strings.TrimSpace(memberIDType) memberID = strings.TrimSpace(memberID) if memberIDType == "" || memberID == "" { - return feishuSenderProfile{}, fmt.Errorf("empty member lookup input") + return feishuSenderProfile{}, errors.New("empty member lookup input") } pageToken := "" for page := 0; page < 5; page++ { diff --git a/internal/channel/adapters/feishu/stream.go b/internal/channel/adapters/feishu/stream.go index 9bb3d4e2..3e93e6df 100644 --- a/internal/channel/adapters/feishu/stream.go +++ b/internal/channel/adapters/feishu/stream.go @@ -3,6 +3,7 @@ package feishu import ( "context" "encoding/json" + "errors" "fmt" "regexp" "strings" @@ -41,10 +42,10 @@ type feishuOutboundStream struct { func (s *feishuOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { if s == nil || s.adapter == nil { - return fmt.Errorf("feishu stream not configured") + return errors.New("feishu stream not configured") } if s.closed.Load() { - return fmt.Errorf("feishu stream is closed") + return errors.New("feishu stream is closed") } select { case <-ctx.Done(): @@ -163,7 +164,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro return nil } if s.client == nil { - return fmt.Errorf("feishu client not configured") + return errors.New("feishu client not configured") } content, err := buildFeishuStreamCardContent(text) if err != nil { @@ -178,7 +179,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro Uuid(uuid.NewString()). Build()). Build() - replyResp, err := s.client.Im.V1.Message.Reply(ctx, replyReq) + replyResp, err := s.client.Im.Message.Reply(ctx, replyReq) if err != nil { return err } @@ -190,7 +191,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro return fmt.Errorf("feishu stream reply failed: %s (code: %d)", msg, code) } if replyResp.Data == nil || replyResp.Data.MessageId == nil || strings.TrimSpace(*replyResp.Data.MessageId) == "" { - return fmt.Errorf("feishu stream reply failed: empty message id") + return errors.New("feishu stream reply failed: empty message id") } s.cardMessageID = strings.TrimSpace(*replyResp.Data.MessageId) s.lastPatched = normalizeFeishuStreamText(text) @@ -206,7 +207,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro Uuid(uuid.NewString()). Build()). Build() - createResp, err := s.client.Im.V1.Message.Create(ctx, createReq) + createResp, err := s.client.Im.Message.Create(ctx, createReq) if err != nil { return err } @@ -218,7 +219,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro return fmt.Errorf("feishu stream create failed: %s (code: %d)", msg, code) } if createResp.Data == nil || createResp.Data.MessageId == nil || strings.TrimSpace(*createResp.Data.MessageId) == "" { - return fmt.Errorf("feishu stream create failed: empty message id") + return errors.New("feishu stream create failed: empty message id") } s.cardMessageID = strings.TrimSpace(*createResp.Data.MessageId) s.lastPatched = normalizeFeishuStreamText(text) @@ -228,7 +229,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro func (s *feishuOutboundStream) patchCard(ctx context.Context, text string) error { if strings.TrimSpace(s.cardMessageID) == "" { - return fmt.Errorf("feishu stream card message not initialized") + return errors.New("feishu stream card message not initialized") } contentText := normalizeFeishuStreamText(text) if contentText == s.lastPatched { @@ -244,7 +245,7 @@ func (s *feishuOutboundStream) patchCard(ctx context.Context, text string) error Content(content). Build()). Build() - patchResp, err := s.client.Im.V1.Message.Patch(ctx, patchReq) + patchResp, err := s.client.Im.Message.Patch(ctx, patchReq) if err != nil { return err } diff --git a/internal/channel/adapters/feishu/webhook_handler.go b/internal/channel/adapters/feishu/webhook_handler.go index 79c272fb..52115e40 100644 --- a/internal/channel/adapters/feishu/webhook_handler.go +++ b/internal/channel/adapters/feishu/webhook_handler.go @@ -61,7 +61,7 @@ func (h *WebhookHandler) Register(e *echo.Echo) { } // HandleProbe responds to health/probe requests on the webhook URL. -func (h *WebhookHandler) HandleProbe(c echo.Context) error { +func (*WebhookHandler) HandleProbe(c echo.Context) error { return c.String(http.StatusOK, "ok") } @@ -102,15 +102,16 @@ func (h *WebhookHandler) Handle(c echo.Context) error { botOpenID := h.adapter.resolveBotOpenID(context.WithoutCancel(c.Request().Context()), cfg) + reqCtx := c.Request().Context() eventDispatcher := dispatcher.NewEventDispatcher(feishuCfg.VerificationToken, feishuCfg.EncryptKey) eventDispatcher.OnP2MessageReceiveV1(func(_ context.Context, event *larkim.P2MessageReceiveV1) error { - msg := extractFeishuInbound(event, botOpenID) + msg := extractFeishuInbound(event, botOpenID, h.adapter.logger) if strings.TrimSpace(msg.Message.PlainText()) == "" && len(msg.Message.Attachments) == 0 { return nil } - h.adapter.enrichSenderProfile(context.WithoutCancel(c.Request().Context()), cfg, event, &msg) + h.adapter.enrichSenderProfile(reqCtx, cfg, event, &msg) msg.BotID = cfg.BotID - return h.manager.HandleInbound(context.WithoutCancel(c.Request().Context()), cfg, msg) + return h.manager.HandleInbound(reqCtx, cfg, msg) }) resp := eventDispatcher.Handle(c.Request().Context(), &larkevent.EventReq{ diff --git a/internal/channel/adapters/feishu/webhook_handler_test.go b/internal/channel/adapters/feishu/webhook_handler_test.go index 05da4375..ee593ea6 100644 --- a/internal/channel/adapters/feishu/webhook_handler_test.go +++ b/internal/channel/adapters/feishu/webhook_handler_test.go @@ -2,6 +2,7 @@ package feishu import ( "context" + "errors" "net/http" "net/http/httptest" "strings" @@ -17,7 +18,7 @@ type fakeWebhookStore struct { err error } -func (s *fakeWebhookStore) ListConfigsByType(ctx context.Context, channelType channel.ChannelType) ([]channel.ChannelConfig, error) { +func (s *fakeWebhookStore) ListConfigsByType(_ context.Context, _ channel.ChannelType) ([]channel.ChannelConfig, error) { if s.err != nil { return nil, s.err } @@ -32,7 +33,7 @@ type fakeWebhookManager struct { err error } -func (m *fakeWebhookManager) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) error { +func (m *fakeWebhookManager) HandleInbound(_ context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) error { m.calls = append(m.calls, struct { cfg channel.ChannelConfig msg channel.InboundMessage @@ -237,7 +238,8 @@ func TestWebhookHandler_EventCallbackRejectsInvalidTokenWhenEncryptKeyMissing(t if err == nil { t.Fatal("expected unauthorized error") } - he, ok := err.(*echo.HTTPError) + he := &echo.HTTPError{} + ok := errors.As(err, &he) if !ok { t.Fatalf("expected HTTPError, got %T", err) } @@ -282,7 +284,8 @@ func TestWebhookHandler_EventCallbackRequiresVerificationTokenWhenEncryptKeyMiss if err == nil { t.Fatal("expected forbidden error") } - he, ok := err.(*echo.HTTPError) + he := &echo.HTTPError{} + ok := errors.As(err, &he) if !ok { t.Fatalf("expected HTTPError, got %T", err) } @@ -327,7 +330,8 @@ func TestWebhookHandler_RejectsOversizedBody(t *testing.T) { if err == nil { t.Fatal("expected payload-too-large error") } - he, ok := err.(*echo.HTTPError) + he := &echo.HTTPError{} + ok := errors.As(err, &he) if !ok { t.Fatalf("expected HTTPError, got %T", err) } diff --git a/internal/channel/adapters/local/broadcaster_test.go b/internal/channel/adapters/local/broadcaster_test.go index 4a909ad6..a03256e2 100644 --- a/internal/channel/adapters/local/broadcaster_test.go +++ b/internal/channel/adapters/local/broadcaster_test.go @@ -55,7 +55,7 @@ func TestRouteHubBroadcaster_EmptyBotID(t *testing.T) { } } -func TestRouteHubBroadcaster_NilHub(t *testing.T) { +func TestRouteHubBroadcaster_NilHub(_ *testing.T) { broadcaster := NewRouteHubBroadcaster(nil) // Must not panic. broadcaster.OnStreamEvent(context.Background(), "bot1", "telegram", channel.StreamEvent{ diff --git a/internal/channel/adapters/local/cli.go b/internal/channel/adapters/local/cli.go index 3b026fb6..8a39aa34 100644 --- a/internal/channel/adapters/local/cli.go +++ b/internal/channel/adapters/local/cli.go @@ -2,7 +2,7 @@ package local import ( "context" - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" @@ -19,12 +19,12 @@ func NewCLIAdapter(hub *RouteHub) *CLIAdapter { } // Type returns the CLI channel type. -func (a *CLIAdapter) Type() channel.ChannelType { +func (*CLIAdapter) Type() channel.ChannelType { return CLIType } // Descriptor returns the CLI channel metadata. -func (a *CLIAdapter) Descriptor() channel.Descriptor { +func (*CLIAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: CLIType, DisplayName: "CLI", @@ -46,29 +46,29 @@ func (a *CLIAdapter) Descriptor() channel.Descriptor { } // Send publishes an outbound message to the CLI route hub. -func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { +func (a *CLIAdapter) Send(_ context.Context, _ channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { - return fmt.Errorf("cli hub not configured") + return errors.New("cli hub not configured") } target := strings.TrimSpace(msg.Target) if target == "" { - return fmt.Errorf("cli target is required") + return errors.New("cli target is required") } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } a.hub.Publish(target, msg) return nil } // OpenStream opens a local stream session bound to the target route. -func (a *CLIAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (a *CLIAdapter) OpenStream(ctx context.Context, _ channel.ChannelConfig, target string, _ channel.StreamOptions) (channel.OutboundStream, error) { if a.hub == nil { - return nil, fmt.Errorf("cli hub not configured") + return nil, errors.New("cli hub not configured") } target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("cli target is required") + return nil, errors.New("cli target is required") } select { case <-ctx.Done(): diff --git a/internal/channel/adapters/local/hub.go b/internal/channel/adapters/local/hub.go index 0fef9edb..935a2f93 100644 --- a/internal/channel/adapters/local/hub.go +++ b/internal/channel/adapters/local/hub.go @@ -2,7 +2,7 @@ package local import ( "context" - "fmt" + "errors" "sync" "sync/atomic" @@ -107,10 +107,10 @@ func newLocalOutboundStream(hub *RouteHub, target string) channel.OutboundStream func (s *localOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { if s == nil || s.hub == nil { - return fmt.Errorf("route hub not configured") + return errors.New("route hub not configured") } if s.closed.Load() { - return fmt.Errorf("stream is closed") + return errors.New("stream is closed") } select { case <-ctx.Done(): diff --git a/internal/channel/adapters/local/web.go b/internal/channel/adapters/local/web.go index 70309748..0a0f9cd5 100644 --- a/internal/channel/adapters/local/web.go +++ b/internal/channel/adapters/local/web.go @@ -2,7 +2,7 @@ package local import ( "context" - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" @@ -19,12 +19,12 @@ func NewWebAdapter(hub *RouteHub) *WebAdapter { } // Type returns the Web channel type. -func (a *WebAdapter) Type() channel.ChannelType { +func (*WebAdapter) Type() channel.ChannelType { return WebType } // Descriptor returns the Web channel metadata. -func (a *WebAdapter) Descriptor() channel.Descriptor { +func (*WebAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: WebType, DisplayName: "Web", @@ -46,29 +46,29 @@ func (a *WebAdapter) Descriptor() channel.Descriptor { } // Send publishes an outbound message to the Web route hub. -func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { +func (a *WebAdapter) Send(_ context.Context, _ channel.ChannelConfig, msg channel.OutboundMessage) error { if a.hub == nil { - return fmt.Errorf("web hub not configured") + return errors.New("web hub not configured") } target := strings.TrimSpace(msg.Target) if target == "" { - return fmt.Errorf("web target is required") + return errors.New("web target is required") } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } a.hub.Publish(target, msg) return nil } // OpenStream opens a local stream session bound to the target route. -func (a *WebAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (a *WebAdapter) OpenStream(ctx context.Context, _ channel.ChannelConfig, target string, _ channel.StreamOptions) (channel.OutboundStream, error) { if a.hub == nil { - return nil, fmt.Errorf("web hub not configured") + return nil, errors.New("web hub not configured") } target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("web target is required") + return nil, errors.New("web target is required") } select { case <-ctx.Done(): diff --git a/internal/channel/adapters/telegram/config.go b/internal/channel/adapters/telegram/config.go index 20d22033..a2b0e8e4 100644 --- a/internal/channel/adapters/telegram/config.go +++ b/internal/channel/adapters/telegram/config.go @@ -1,7 +1,7 @@ package telegram import ( - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" @@ -90,7 +90,7 @@ func resolveTarget(raw map[string]any) (string, error) { } return name, nil } - return "", fmt.Errorf("telegram binding is incomplete") + return "", errors.New("telegram binding is incomplete") } func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { @@ -132,7 +132,7 @@ func buildUserConfig(identity channel.Identity) map[string]any { func parseConfig(raw map[string]any) (Config, error) { token := strings.TrimSpace(channel.ReadString(raw, "botToken", "bot_token")) if token == "" { - return Config{}, fmt.Errorf("telegram botToken is required") + return Config{}, errors.New("telegram botToken is required") } apiBaseURL := strings.TrimSpace(channel.ReadString(raw, "apiBaseURL", "api_base_url")) return Config{BotToken: token, APIBaseURL: apiBaseURL}, nil @@ -143,7 +143,7 @@ func parseUserConfig(raw map[string]any) (UserConfig, error) { userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) chatID := strings.TrimSpace(channel.ReadString(raw, "chatId", "chat_id")) if username == "" && userID == "" && chatID == "" { - return UserConfig{}, fmt.Errorf("telegram user config requires username, user_id, or chat_id") + return UserConfig{}, errors.New("telegram user config requires username, user_id, or chat_id") } return UserConfig{ Username: username, diff --git a/internal/channel/adapters/telegram/directory.go b/internal/channel/adapters/telegram/directory.go index 26220b53..92eece7a 100644 --- a/internal/channel/adapters/telegram/directory.go +++ b/internal/channel/adapters/telegram/directory.go @@ -2,6 +2,7 @@ package telegram import ( "context" + "errors" "fmt" "strconv" "strings" @@ -27,17 +28,17 @@ func directoryLimit(n int) int { } // ListPeers returns users the bot can reach. Telegram Bot API does not provide a list of users; returns empty. -func (a *TelegramAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*TelegramAdapter) ListPeers(_ context.Context, _ channel.ChannelConfig, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } // ListGroups returns chats the bot is in. Telegram Bot API does not provide a list of chats; returns empty. -func (a *TelegramAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*TelegramAdapter) ListGroups(_ context.Context, _ channel.ChannelConfig, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } // ListGroupMembers returns group managers for the given group (Telegram only exposes this list, not all members). -func (a *TelegramAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *TelegramAdapter) ListGroupMembers(_ context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err @@ -96,7 +97,7 @@ func (a *TelegramAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelC } } -func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { +func (a *TelegramAdapter) resolveTelegramUser(_ context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { chatID, userID := parseTelegramUserInput(input) if chatID == 0 && userID == 0 { return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry user: invalid input %q", input) @@ -114,7 +115,7 @@ func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: %w", err) } if member.User == nil { - return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: empty user") + return channel.DirectoryEntry{}, errors.New("telegram get chat member: empty user") } return a.telegramUserToEntryWithAvatar(bot, member.User), nil } @@ -148,7 +149,7 @@ func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi }, nil } -func (a *TelegramAdapter) resolveTelegramGroup(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { +func (a *TelegramAdapter) resolveTelegramGroup(_ context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { chatID, superGroupUsername := parseTelegramChatInput(input) if chatID == 0 && superGroupUsername == "" { return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry group: invalid input %q", input) diff --git a/internal/channel/adapters/telegram/logger.go b/internal/channel/adapters/telegram/logger.go index e2fa7645..ca97c112 100644 --- a/internal/channel/adapters/telegram/logger.go +++ b/internal/channel/adapters/telegram/logger.go @@ -3,17 +3,46 @@ package telegram import ( "fmt" "log/slog" + "sync" ) // slogBotLogger adapts slog.Logger to tgbotapi.BotLogger so library logs go through slog. type slogBotLogger struct { + mu sync.RWMutex log *slog.Logger } +func newSlogBotLogger(log *slog.Logger) *slogBotLogger { + logger := &slogBotLogger{} + logger.SetLogger(log) + return logger +} + +func (s *slogBotLogger) SetLogger(log *slog.Logger) { + s.mu.Lock() + defer s.mu.Unlock() + if log == nil { + log = slog.Default() + } + s.log = log +} + +func (s *slogBotLogger) current() *slog.Logger { + s.mu.RLock() + defer s.mu.RUnlock() + if s.log == nil { + return slog.Default() + } + return s.log +} + func (s *slogBotLogger) Println(v ...interface{}) { - s.log.Warn(fmt.Sprint(v...)) + s.current().Warn("telegram bot sdk log", slog.String("message", fmt.Sprint(v...))) } func (s *slogBotLogger) Printf(format string, v ...interface{}) { - s.log.Warn(fmt.Sprintf(format, v...)) + s.current().Warn( + "telegram bot sdk log", + slog.String("message", fmt.Sprintf(format, v...)), + ) } diff --git a/internal/channel/adapters/telegram/markdown.go b/internal/channel/adapters/telegram/markdown.go index a3fba424..18f424d7 100644 --- a/internal/channel/adapters/telegram/markdown.go +++ b/internal/channel/adapters/telegram/markdown.go @@ -16,14 +16,13 @@ const ( ) var ( - reCodeBlockFence = regexp.MustCompile("(?s)```(\\w*)\\n?(.*?)```") - reInlineCode = regexp.MustCompile("`([^`\\n]+?)`") - reBold = regexp.MustCompile(`\*\*(.+?)\*\*`) - reStrike = regexp.MustCompile(`~~(.+?)~~`) - reLink = regexp.MustCompile(`\[([^\]]+?)\]\(([^)]+?)\)`) - reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) - reListBullet = regexp.MustCompile(`(?m)^(\s*)[-+]\s`) - reItalic = regexp.MustCompile(`\*([^*\n]+?)\*`) + reInlineCode = regexp.MustCompile("`([^`\\n]+?)`") + reBold = regexp.MustCompile(`\*\*(.+?)\*\*`) + reStrike = regexp.MustCompile(`~~(.+?)~~`) + reLink = regexp.MustCompile(`\[([^\]]+?)\]\(([^)]+?)\)`) + reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) + reListBullet = regexp.MustCompile(`(?m)^(\s*)[-+]\s`) + reItalic = regexp.MustCompile(`\*([^*\n]+?)\*`) ) // formatTelegramOutput converts standard markdown to Telegram-compatible HTML @@ -64,7 +63,7 @@ func markdownToTelegramHTML(text string) string { lang, code := extractCodeBlockLang(seg) escaped := telegramEscapeHTML(strings.TrimRight(code, "\n")) if lang != "" { - buf.WriteString(fmt.Sprintf("
%s
", lang, escaped)) + fmt.Fprintf(&buf, "
%s
", lang, escaped) } else { buf.WriteString("
" + escaped + "
") } diff --git a/internal/channel/adapters/telegram/stream.go b/internal/channel/adapters/telegram/stream.go index 42eb0341..f6b5984b 100644 --- a/internal/channel/adapters/telegram/stream.go +++ b/internal/channel/adapters/telegram/stream.go @@ -2,7 +2,7 @@ package telegram import ( "context" - "fmt" + "errors" "log/slog" "strings" "sync" @@ -14,10 +14,12 @@ import ( "github.com/memohai/memoh/internal/channel" ) -const telegramStreamEditThrottle = 5000 * time.Millisecond -const telegramDraftThrottle = 300 * time.Millisecond -const telegramStreamToolHintText = "Calling tools..." -const telegramStreamPendingSuffix = "\n……" +const ( + telegramStreamEditThrottle = 5000 * time.Millisecond + telegramDraftThrottle = 300 * time.Millisecond + telegramStreamToolHintText = "Calling tools..." + telegramStreamPendingSuffix = "\n……" +) var testEditFunc func(bot *tgbotapi.BotAPI, chatID int64, msgID int, text string, parseMode string) error @@ -38,7 +40,7 @@ type telegramOutboundStream struct { lastEditedAt time.Time } -func (s *telegramOutboundStream) getBot(ctx context.Context) (bot *tgbotapi.BotAPI, err error) { +func (s *telegramOutboundStream) getBot(_ context.Context) (bot *tgbotapi.BotAPI, err error) { telegramCfg, err := parseConfig(s.cfg.Credentials) if err != nil { return nil, err @@ -75,7 +77,9 @@ func (s *telegramOutboundStream) ensureStreamMessage(ctx context.Context, text s s.mu.Lock() go func() { if err := s.refreshTypingAction(ctx); err != nil { - slog.Debug("refresh typing action failed", slog.Any("err", err)) + if s.adapter != nil && s.adapter.logger != nil { + s.adapter.logger.Debug("refresh typing action failed", slog.Any("error", err)) + } } }() if s.streamMsgID != 0 { @@ -263,10 +267,10 @@ func (s *telegramOutboundStream) sendPermanentMessage(ctx context.Context, text func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { if s == nil || s.adapter == nil { - return fmt.Errorf("telegram stream not configured") + return errors.New("telegram stream not configured") } if s.closed.Load() { - return fmt.Errorf("telegram stream is closed") + return errors.New("telegram stream is closed") } select { case <-ctx.Done(): @@ -285,7 +289,9 @@ func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamE // In draft mode, send buffered text as a permanent message before tool execution. if bufText != "" { if err := s.sendPermanentMessage(ctx, bufText, ""); err != nil { - slog.Warn("telegram: draft permanent message failed", slog.Any("error", err)) + if s.adapter != nil && s.adapter.logger != nil { + s.adapter.logger.Warn("telegram: draft permanent message failed", slog.Any("error", err)) + } } } } else if hasMsg && bufText != "" { @@ -322,11 +328,13 @@ func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamE } for _, att := range event.Attachments { if sendErr := sendTelegramAttachmentWithAssets(ctx, bot, s.target, att, "", replyTo, "", s.adapter.assets); sendErr != nil { - slog.Warn("telegram: stream attachment send failed", - slog.String("config_id", s.cfg.ID), - slog.String("type", string(att.Type)), - slog.Any("error", sendErr), - ) + if s.adapter != nil && s.adapter.logger != nil { + s.adapter.logger.Warn("telegram: stream attachment send failed", + slog.String("config_id", s.cfg.ID), + slog.String("type", string(att.Type)), + slog.Any("error", sendErr), + ) + } } } return nil @@ -381,14 +389,20 @@ func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamE if bufText != "" { if s.isPrivateChat { if err := s.sendPermanentMessage(ctx, bufText, ""); err != nil { - slog.Warn("telegram: draft final permanent message failed", slog.Any("error", err)) + if s.adapter != nil && s.adapter.logger != nil { + s.adapter.logger.Warn("telegram: draft final permanent message failed", slog.Any("error", err)) + } } } else { if err := s.ensureStreamMessage(ctx, bufText); err != nil { - slog.Warn("telegram: ensure stream message failed", slog.Any("error", err)) + if s.adapter != nil && s.adapter.logger != nil { + s.adapter.logger.Warn("telegram: ensure stream message failed", slog.Any("error", err)) + } } if err := s.editStreamMessageFinal(ctx, bufText); err != nil { - slog.Warn("telegram: edit stream message failed", slog.Any("error", err)) + if s.adapter != nil && s.adapter.logger != nil { + s.adapter.logger.Warn("telegram: edit stream message failed", slog.Any("error", err)) + } } } } diff --git a/internal/channel/adapters/telegram/stream_test.go b/internal/channel/adapters/telegram/stream_test.go index e6c92538..5121ec01 100644 --- a/internal/channel/adapters/telegram/stream_test.go +++ b/internal/channel/adapters/telegram/stream_test.go @@ -2,11 +2,13 @@ package telegram import ( "context" + "errors" "strings" "testing" "time" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/memohai/memoh/internal/channel" ) @@ -112,7 +114,7 @@ func TestTelegramOutboundStream_CloseContextCanceled(t *testing.T) { cancel() err := s.Close(ctx) - if err != context.Canceled { + if !errors.Is(err, context.Canceled) { t.Fatalf("Close with canceled context should return context.Canceled: %v", err) } } @@ -183,8 +185,6 @@ func TestEditStreamMessage_NoEditWhenThrottled(t *testing.T) { } func TestEditStreamMessage_429SetsBackoffAndReturnsNil(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) before := time.Now().Add(-time.Minute) s := &telegramOutboundStream{ @@ -231,8 +231,6 @@ func TestEditStreamMessage_429SetsBackoffAndReturnsNil(t *testing.T) { } func TestEditStreamMessageFinal_Success(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, @@ -343,8 +341,6 @@ func TestSendDraft_EmptyTextSkip(t *testing.T) { } func TestSendDraft_Success(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, @@ -391,8 +387,6 @@ func TestSendDraft_Success(t *testing.T) { } func TestSendDraft_429Backoff(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) before := time.Now().Add(-time.Minute) s := &telegramOutboundStream{ @@ -435,8 +429,6 @@ func TestSendDraft_429Backoff(t *testing.T) { } func TestDraftMode_DeltaUsesSendDraft(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, @@ -500,8 +492,6 @@ func TestDraftMode_PhaseEndTextIsNoOp(t *testing.T) { } func TestDraftMode_ToolCallStartSendsPermanentMessage(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, @@ -525,7 +515,7 @@ func TestDraftMode_ToolCallStartSendsPermanentMessage(t *testing.T) { sentText = text return 123, 1, nil } - sendEditForTest = func(_ *tgbotapi.BotAPI, edit tgbotapi.EditMessageTextConfig) error { + sendEditForTest = func(_ *tgbotapi.BotAPI, _ tgbotapi.EditMessageTextConfig) error { t.Error("editMessage should not be called in draft mode") return nil } @@ -557,8 +547,6 @@ func TestDraftMode_ToolCallStartSendsPermanentMessage(t *testing.T) { } func TestDraftMode_FinalEmptyBufferSkipsDuplicate(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, @@ -602,8 +590,6 @@ func TestDraftMode_FinalEmptyBufferSkipsDuplicate(t *testing.T) { // responses), only the first one sends the buffer text as a permanent message. // Subsequent finals find the buffer empty and skip sending. func TestDraftMode_MultipleFinalEventsOnlyOneSend(t *testing.T) { - t.Parallel() - adapter := NewTelegramAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index 5d61b494..b62be257 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -22,8 +22,15 @@ import ( "github.com/memohai/memoh/internal/media" ) -const telegramMaxMessageLength = 4096 -const telegramMediaGroupCollectWindow = 700 * time.Millisecond +const ( + telegramMaxMessageLength = 4096 + telegramMediaGroupCollectWindow = 700 * time.Millisecond +) + +var ( + telegramBotLogger = newSlogBotLogger(nil) + telegramLoggerInitOnce sync.Once +) type telegramMediaGroupBuffer struct { messages []*tgbotapi.Message @@ -54,10 +61,17 @@ func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { bots: make(map[string]*tgbotapi.BotAPI), fileEndpoints: make(map[string]string), } - _ = tgbotapi.SetLogger(&slogBotLogger{log: adapter.logger}) + initTelegramBotLogger(adapter.logger) return adapter } +func initTelegramBotLogger(log *slog.Logger) { + telegramLoggerInitOnce.Do(func() { + _ = tgbotapi.SetLogger(telegramBotLogger) + }) + telegramBotLogger.SetLogger(log) +} + // SetAssetOpener injects the media asset reader for storage-first file delivery. func (a *TelegramAdapter) SetAssetOpener(opener assetOpener) { a.assets = opener @@ -105,16 +119,32 @@ func (a *TelegramAdapter) getFileDirectURL(bot *tgbotapi.BotAPI, fileID string) if endpoint == "" { endpoint = tgbotapi.FileEndpoint } - return fmt.Sprintf(endpoint, bot.Token, file.FilePath), nil + return formatTelegramFileURL(endpoint, bot.Token, file.FilePath), nil +} + +func formatTelegramFileURL(endpoint, token, filePath string) string { + placeholderCount := strings.Count(endpoint, "%s") + switch { + case placeholderCount >= 2: + return fmt.Sprintf(endpoint, token, filePath) + case placeholderCount == 1: + return fmt.Sprintf(endpoint, filePath) + default: + base := strings.TrimRight(strings.TrimSpace(endpoint), "/") + if base == "" { + return filePath + } + return base + "/" + strings.TrimLeft(filePath, "/") + } } // Type returns the Telegram channel type. -func (a *TelegramAdapter) Type() channel.ChannelType { +func (*TelegramAdapter) Type() channel.ChannelType { return Type } // Descriptor returns the Telegram channel metadata. -func (a *TelegramAdapter) Descriptor() channel.Descriptor { +func (*TelegramAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: Type, DisplayName: "Telegram", @@ -162,32 +192,32 @@ func (a *TelegramAdapter) Descriptor() channel.Descriptor { } // NormalizeConfig validates and normalizes a Telegram channel configuration map. -func (a *TelegramAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (*TelegramAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) } // NormalizeUserConfig validates and normalizes a Telegram user-binding configuration map. -func (a *TelegramAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func (*TelegramAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return normalizeUserConfig(raw) } // NormalizeTarget normalizes a Telegram delivery target string. -func (a *TelegramAdapter) NormalizeTarget(raw string) string { +func (*TelegramAdapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } // ResolveTarget derives a delivery target from a Telegram user-binding configuration. -func (a *TelegramAdapter) ResolveTarget(userConfig map[string]any) (string, error) { +func (*TelegramAdapter) ResolveTarget(userConfig map[string]any) (string, error) { return resolveTarget(userConfig) } // MatchBinding reports whether a Telegram user binding matches the given criteria. -func (a *TelegramAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { +func (*TelegramAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { return matchBinding(config, criteria) } // BuildUserConfig constructs a Telegram user-binding config from an Identity. -func (a *TelegramAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (*TelegramAdapter) BuildUserConfig(identity channel.Identity) map[string]any { return buildUserConfig(identity) } @@ -220,8 +250,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig mediaGroups := make(map[string]*telegramMediaGroupBuffer) var mediaGroupsMu sync.Mutex - var flushMediaGroup func(groupKey string) - flushMediaGroup = func(groupKey string) { + flushMediaGroup := func(groupKey string) { var batch []*tgbotapi.Message mediaGroupsMu.Lock() buffer, ok := mediaGroups[groupKey] @@ -449,7 +478,7 @@ func (a *TelegramAdapter) buildTelegramMediaGroupInboundMessage( func (a *TelegramAdapter) toInboundTelegramMessage( bot *tgbotapi.BotAPI, - cfg channel.ChannelConfig, + _ channel.ChannelConfig, raw *tgbotapi.Message, text string, attachments []channel.Attachment, @@ -559,14 +588,14 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m } to := strings.TrimSpace(msg.Target) if to == "" { - return fmt.Errorf("telegram target is required") + return errors.New("telegram target is required") } bot, err := a.getOrCreateBot(telegramCfg, cfg.ID) if err != nil { return err } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } text := strings.TrimSpace(msg.Message.PlainText()) text, parseMode := formatTelegramOutput(text, msg.Message.Format) @@ -606,7 +635,7 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m func (a *TelegramAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("telegram target is required") + return nil, errors.New("telegram target is required") } select { case <-ctx.Done(): @@ -728,7 +757,7 @@ func sendTelegramTextReturnMessage(bot *tgbotapi.BotAPI, target string, text str } else { chatID, err = strconv.ParseInt(target, 10, 64) if err != nil { - return 0, 0, fmt.Errorf("telegram target must be @username or chat_id") + return 0, 0, errors.New("telegram target must be @username or chat_id") } message := tgbotapi.NewMessage(chatID, text) message.ParseMode = parseMode @@ -779,7 +808,7 @@ func sendTelegramDraft(bot *tgbotapi.BotAPI, chatID int64, draftID int, text str return sendDraftForTest(bot, chatID, draftID, text, parseMode) } params := tgbotapi.Params{} - params.AddFirstValid("chat_id", chatID) + _ = params.AddFirstValid("chat_id", chatID) params.AddNonZero("draft_id", draftID) params.AddNonEmpty("text", text) params.AddNonEmpty("parse_mode", parseMode) @@ -824,18 +853,14 @@ func sendTelegramAttachmentWithAssets(ctx context.Context, bot *tgbotapi.BotAPI, return sendTelegramAttachmentImpl(ctx, bot, target, att, caption, replyTo, parseMode, opener) } -func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Attachment, caption string, replyTo int, parseMode string) error { - return sendTelegramAttachmentImpl(context.Background(), bot, target, att, caption, replyTo, parseMode, nil) -} - -func sendTelegramAttachmentImpl(_ context.Context, bot *tgbotapi.BotAPI, target string, att channel.Attachment, caption string, replyTo int, parseMode string, opener assetOpener) error { +func sendTelegramAttachmentImpl(ctx context.Context, bot *tgbotapi.BotAPI, target string, att channel.Attachment, caption string, replyTo int, parseMode string, opener assetOpener) error { urlRef := strings.TrimSpace(att.URL) keyRef := strings.TrimSpace(att.PlatformKey) sourcePlatform := strings.TrimSpace(att.SourcePlatform) base64Ref := strings.TrimSpace(att.Base64) assetID := strings.TrimSpace(att.ContentHash) if urlRef == "" && keyRef == "" && base64Ref == "" && assetID == "" { - return fmt.Errorf("attachment reference is required") + return errors.New("attachment reference is required") } if strings.TrimSpace(caption) == "" && strings.TrimSpace(att.Caption) != "" { caption = strings.TrimSpace(att.Caption) @@ -846,7 +871,7 @@ func sendTelegramAttachmentImpl(_ context.Context, bot *tgbotapi.BotAPI, target botID = bid } } - file, err := resolveTelegramFile(urlRef, keyRef, base64Ref, sourcePlatform, att, assetID, botID, opener) + file, err := resolveTelegramFile(ctx, urlRef, keyRef, base64Ref, sourcePlatform, att, assetID, botID, opener) if err != nil { return err } @@ -859,7 +884,7 @@ func sendTelegramAttachmentImpl(_ context.Context, bot *tgbotapi.BotAPI, target } else { chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return fmt.Errorf("telegram target must be @username or chat_id") + return errors.New("telegram target must be @username or chat_id") } photo = tgbotapi.NewPhoto(chatID, file) } @@ -882,7 +907,7 @@ func sendTelegramAttachmentImpl(_ context.Context, bot *tgbotapi.BotAPI, target } else { chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return fmt.Errorf("telegram target must be @username or chat_id") + return errors.New("telegram target must be @username or chat_id") } document = tgbotapi.NewDocument(chatID, file) } @@ -948,12 +973,12 @@ func sendTelegramAttachmentImpl(_ context.Context, bot *tgbotapi.BotAPI, target // resolveTelegramFile determines the best tgbotapi.RequestFileData for an attachment. // Priority: PlatformKey > ContentHash (storage) > public URL > base64 data URL. -func resolveTelegramFile(urlRef, keyRef, base64Ref, sourcePlatform string, att channel.Attachment, assetID, botID string, opener assetOpener) (tgbotapi.RequestFileData, error) { +func resolveTelegramFile(ctx context.Context, urlRef, keyRef, base64Ref, sourcePlatform string, att channel.Attachment, assetID, botID string, opener assetOpener) (tgbotapi.RequestFileData, error) { if keyRef != "" && (sourcePlatform == "" || strings.EqualFold(sourcePlatform, Type.String())) { return tgbotapi.FileID(keyRef), nil } if assetID != "" && opener != nil { - reader, asset, err := opener.Open(context.Background(), botID, assetID) + reader, asset, err := opener.Open(ctx, botID, assetID) if err == nil { data, readErr := io.ReadAll(io.LimitReader(reader, media.MaxAssetBytes+1)) _ = reader.Close() @@ -987,7 +1012,7 @@ func resolveTelegramFile(urlRef, keyRef, base64Ref, sourcePlatform string, att c if urlRef != "" { return tgbotapi.FileURL(urlRef), nil } - return nil, fmt.Errorf("no usable attachment reference for telegram") + return nil, errors.New("no usable attachment reference for telegram") } func decodeDataURLBytes(dataURL string) ([]byte, error) { @@ -1138,7 +1163,7 @@ func buildTelegramAudio(target string, file tgbotapi.RequestFileData) (tgbotapi. } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.AudioConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.AudioConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewAudio(chatID, file), nil } @@ -1151,7 +1176,7 @@ func buildTelegramVoice(target string, file tgbotapi.RequestFileData) (tgbotapi. } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.VoiceConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.VoiceConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewVoice(chatID, file), nil } @@ -1164,7 +1189,7 @@ func buildTelegramVideo(target string, file tgbotapi.RequestFileData) (tgbotapi. } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.VideoConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.VideoConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewVideo(chatID, file), nil } @@ -1177,7 +1202,7 @@ func buildTelegramAnimation(target string, file tgbotapi.RequestFileData) (tgbot } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.AnimationConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.AnimationConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewAnimation(chatID, file), nil } @@ -1355,7 +1380,7 @@ func (a *TelegramAdapter) buildTelegramAttachment(bot *tgbotapi.BotAPI, attType func (a *TelegramAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) { fileID := strings.TrimSpace(attachment.PlatformKey) if fileID == "" && strings.TrimSpace(attachment.URL) == "" { - return channel.AttachmentPayload{}, fmt.Errorf("telegram attachment requires platform_key or url") + return channel.AttachmentPayload{}, errors.New("telegram attachment requires platform_key or url") } telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { @@ -1377,7 +1402,7 @@ func (a *TelegramAdapter) ResolveAttachment(ctx context.Context, cfg channel.Cha return channel.AttachmentPayload{}, fmt.Errorf("build download request: %w", err) } client := &http.Client{Timeout: 60 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: URL is a Telegram file download URL from the Telegram Bot API if err != nil { return channel.AttachmentPayload{}, fmt.Errorf("download attachment: %w", err) } @@ -1416,7 +1441,7 @@ func (a *TelegramAdapter) ResolveAttachment(ctx context.Context, cfg channel.Cha } // DiscoverSelf retrieves the bot's own identity from the Telegram platform. -func (a *TelegramAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { +func (a *TelegramAdapter) DiscoverSelf(_ context.Context, credentials map[string]any) (map[string]any, string, error) { cfg, err := parseConfig(credentials) if err != nil { return nil, "", err @@ -1503,7 +1528,7 @@ func truncateTelegramText(text string) string { } // ProcessingStarted sends a "typing" chat action to indicate processing. -func (a *TelegramAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (a *TelegramAdapter) ProcessingStarted(_ context.Context, cfg channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { chatID := strings.TrimSpace(info.ReplyTarget) if chatID == "" { return channel.ProcessingStatusHandle{}, nil @@ -1523,12 +1548,12 @@ func (a *TelegramAdapter) ProcessingStarted(ctx context.Context, cfg channel.Cha } // ProcessingCompleted is a no-op for Telegram (typing indicator clears automatically). -func (a *TelegramAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (*TelegramAdapter) ProcessingCompleted(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle) error { return nil } // ProcessingFailed is a no-op for Telegram (typing indicator clears automatically). -func (a *TelegramAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (*TelegramAdapter) ProcessingFailed(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle, _ error) error { return nil } @@ -1561,7 +1586,7 @@ func clearTelegramReaction(bot *tgbotapi.BotAPI, chatID, messageID string) error } // React adds an emoji reaction to a message (implements channel.Reactor). -func (a *TelegramAdapter) React(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { +func (a *TelegramAdapter) React(_ context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { return err @@ -1575,7 +1600,7 @@ func (a *TelegramAdapter) React(ctx context.Context, cfg channel.ChannelConfig, // Unreact removes the bot's reaction from a message (implements channel.Reactor). // The emoji parameter is ignored; Telegram clears all bot reactions at once. -func (a *TelegramAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, _ string) error { +func (a *TelegramAdapter) Unreact(_ context.Context, cfg channel.ChannelConfig, target string, messageID string, _ string) error { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { return err diff --git a/internal/channel/adapters/telegram/telegram_test.go b/internal/channel/adapters/telegram/telegram_test.go index 186abd5e..42e8b172 100644 --- a/internal/channel/adapters/telegram/telegram_test.go +++ b/internal/channel/adapters/telegram/telegram_test.go @@ -2,6 +2,7 @@ package telegram import ( "context" + "errors" "fmt" "io" "strings" @@ -10,6 +11,7 @@ import ( "unicode/utf8" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/memohai/memoh/internal/channel" "github.com/memohai/memoh/internal/media" ) @@ -513,7 +515,7 @@ func TestIsTelegramMessageNotModified(t *testing.T) { want bool }{ {"nil", nil, false}, - {"plain error", fmt.Errorf("network error"), false}, + {"plain error", errors.New("network error"), false}, {"other api error", tgbotapi.Error{Code: 400, Message: "Bad Request: chat not found"}, false}, {"message is not modified", tgbotapi.Error{Code: 400, Message: productionMessageNotModified}, true}, {"production exact", tgbotapi.Error{Code: 400, Message: productionMessageNotModified}, true}, @@ -707,7 +709,7 @@ func TestProcessingFailed_DelegatesToCompleted(t *testing.T) { adapter := NewTelegramAdapter(nil) ctx := context.Background() - err := adapter.ProcessingFailed(ctx, channel.ChannelConfig{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{}, channel.ProcessingStatusHandle{}, fmt.Errorf("test")) + err := adapter.ProcessingFailed(ctx, channel.ChannelConfig{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{}, channel.ProcessingStatusHandle{}, errors.New("test")) if err != nil { t.Fatalf("empty handle should be no-op: %v", err) } @@ -717,7 +719,7 @@ func TestResolveTelegramFile_PlatformKey(t *testing.T) { t.Parallel() att := channel.Attachment{Type: channel.AttachmentImage, PlatformKey: "file_id_123"} - file, err := resolveTelegramFile("", "file_id_123", "", "", att, "", "", nil) + file, err := resolveTelegramFile(context.Background(), "", "file_id_123", "", "", att, "", "", nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -730,7 +732,7 @@ func TestResolveTelegramFile_PublicURL(t *testing.T) { t.Parallel() att := channel.Attachment{Type: channel.AttachmentImage} - file, err := resolveTelegramFile("https://example.com/img.png", "", "", "", att, "", "", nil) + file, err := resolveTelegramFile(context.Background(), "https://example.com/img.png", "", "", "", att, "", "", nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -744,7 +746,7 @@ func TestResolveTelegramFile_DataURL(t *testing.T) { dataURL := "data:image/png;base64,iVBORw0KGgo=" att := channel.Attachment{Type: channel.AttachmentImage, Mime: "image/png", Name: "test.png"} - file, err := resolveTelegramFile("", "", dataURL, "", att, "", "", nil) + file, err := resolveTelegramFile(context.Background(), "", "", dataURL, "", att, "", "", nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -764,7 +766,7 @@ func TestResolveTelegramFile_NoReference(t *testing.T) { t.Parallel() att := channel.Attachment{Type: channel.AttachmentImage} - _, err := resolveTelegramFile("", "", "", "", att, "", "", nil) + _, err := resolveTelegramFile(context.Background(), "", "", "", "", att, "", "", nil) if err == nil { t.Fatal("expected error when no reference available") } @@ -775,7 +777,7 @@ func TestResolveTelegramFile_ContainerPathFallsToBase64(t *testing.T) { dataURL := "data:image/jpeg;base64,/9j/4AAQ" att := channel.Attachment{Type: channel.AttachmentImage, Mime: "image/jpeg"} - file, err := resolveTelegramFile("/data/media/image/a.jpg", "", dataURL, "", att, "", "", nil) + file, err := resolveTelegramFile(context.Background(), "/data/media/image/a.jpg", "", dataURL, "", att, "", "", nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -798,7 +800,7 @@ func TestResolveTelegramFile_ContentHash(t *testing.T) { opener := &mockAssetOpener{data: []byte("fake-png-bytes"), mime: "image/png"} att := channel.Attachment{Type: channel.AttachmentImage, ContentHash: "asset-123", Name: "output.png"} - file, err := resolveTelegramFile("", "", "", "", att, "asset-123", "bot-1", opener) + file, err := resolveTelegramFile(context.Background(), "", "", "", "", att, "asset-123", "bot-1", opener) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -819,7 +821,7 @@ func TestResolveTelegramFile_ContentHashPriorityOverURL(t *testing.T) { opener := &mockAssetOpener{data: []byte("from-storage"), mime: "image/jpeg"} att := channel.Attachment{Type: channel.AttachmentImage, ContentHash: "a1"} - file, err := resolveTelegramFile("https://example.com/fallback.jpg", "", "", "", att, "a1", "bot-1", opener) + file, err := resolveTelegramFile(context.Background(), "https://example.com/fallback.jpg", "", "", "", att, "a1", "bot-1", opener) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/internal/channel/capabilities.go b/internal/channel/capabilities.go index 2de4c90c..a1cf379c 100644 --- a/internal/channel/capabilities.go +++ b/internal/channel/capabilities.go @@ -3,20 +3,20 @@ package channel // ChannelCapabilities describes the feature matrix of a channel type. // It is used by the outbound layer to validate message content before delivery. type ChannelCapabilities struct { - Text bool `json:"text"` - Markdown bool `json:"markdown"` - RichText bool `json:"rich_text"` - Attachments bool `json:"attachments"` - Media bool `json:"media"` - Reactions bool `json:"reactions"` - Buttons bool `json:"buttons"` - Reply bool `json:"reply"` - Threads bool `json:"threads"` - Streaming bool `json:"streaming"` - Polls bool `json:"polls"` - Edit bool `json:"edit"` - Unsend bool `json:"unsend"` - NativeCommands bool `json:"native_commands"` - BlockStreaming bool `json:"block_streaming"` - ChatTypes []string `json:"chat_types,omitempty"` + Text bool `json:"text"` + Markdown bool `json:"markdown"` + RichText bool `json:"rich_text"` + Attachments bool `json:"attachments"` + Media bool `json:"media"` + Reactions bool `json:"reactions"` + Buttons bool `json:"buttons"` + Reply bool `json:"reply"` + Threads bool `json:"threads"` + Streaming bool `json:"streaming"` + Polls bool `json:"polls"` + Edit bool `json:"edit"` + Unsend bool `json:"unsend"` + NativeCommands bool `json:"native_commands"` + BlockStreaming bool `json:"block_streaming"` + ChatTypes []string `json:"chat_types,omitempty"` } diff --git a/internal/channel/config_test.go b/internal/channel/config_test.go index b71e232e..f8b61cc4 100644 --- a/internal/channel/config_test.go +++ b/internal/channel/config_test.go @@ -1,7 +1,7 @@ package channel_test import ( - "fmt" + "errors" "testing" "github.com/memohai/memoh/internal/channel" @@ -12,8 +12,8 @@ const testChannelType = channel.ChannelType("test-config") // testConfigAdapter implements Adapter, ConfigNormalizer, TargetResolver, BindingMatcher for tests. type testConfigAdapter struct{} -func (a *testConfigAdapter) Type() channel.ChannelType { return testChannelType } -func (a *testConfigAdapter) Descriptor() channel.Descriptor { +func (*testConfigAdapter) Type() channel.ChannelType { return testChannelType } +func (*testConfigAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: testChannelType, DisplayName: "Test", @@ -35,38 +35,38 @@ func (a *testConfigAdapter) Descriptor() channel.Descriptor { } } -func (a *testConfigAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (*testConfigAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { value := channel.ReadString(raw, "value") if value == "" { - return nil, fmt.Errorf("value is required") + return nil, errors.New("value is required") } return map[string]any{"value": value}, nil } -func (a *testConfigAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func (*testConfigAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { value := channel.ReadString(raw, "user") if value == "" { - return nil, fmt.Errorf("user is required") + return nil, errors.New("user is required") } return map[string]any{"user": value}, nil } -func (a *testConfigAdapter) NormalizeTarget(raw string) string { return raw } +func (*testConfigAdapter) NormalizeTarget(raw string) string { return raw } -func (a *testConfigAdapter) ResolveTarget(raw map[string]any) (string, error) { +func (*testConfigAdapter) ResolveTarget(raw map[string]any) (string, error) { value := channel.ReadString(raw, "target") if value == "" { - return "", fmt.Errorf("target is required") + return "", errors.New("target is required") } return "resolved:" + value, nil } -func (a *testConfigAdapter) MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { +func (*testConfigAdapter) MatchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { value := channel.ReadString(raw, "user") return value != "" && value == criteria.SubjectID } -func (a *testConfigAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (*testConfigAdapter) BuildUserConfig(_ channel.Identity) map[string]any { return map[string]any{} } diff --git a/internal/channel/connection.go b/internal/channel/connection.go index bbda050f..e48cfba1 100644 --- a/internal/channel/connection.go +++ b/internal/channel/connection.go @@ -3,7 +3,6 @@ package channel import ( "context" "errors" - "fmt" "log/slog" "strings" "time" @@ -95,7 +94,7 @@ func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error { _, ok := m.registry.GetReceiver(cfg.ChannelType) if !ok { - m.markConnectionStatus(cfg, false, fmt.Errorf("receiver not available")) + m.markConnectionStatus(cfg, false, errors.New("receiver not available")) return nil } @@ -155,7 +154,7 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error receiver, ok := m.registry.GetReceiver(cfg.ChannelType) if !ok { - m.markConnectionStatus(cfg, false, fmt.Errorf("receiver not available")) + m.markConnectionStatus(cfg, false, errors.New("receiver not available")) return nil } @@ -182,11 +181,8 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error for i := len(m.middlewares) - 1; i >= 0; i-- { handler = m.middlewares[i](handler) } - connectCtx := context.Background() - if ctx != nil { - // Decouple long-lived adapter connections from short-lived request contexts. - connectCtx = context.WithoutCancel(ctx) - } + // Decouple long-lived adapter connections from short-lived request contexts. + connectCtx := context.WithoutCancel(ctx) conn, err := receiver.Connect(connectCtx, cfg, handler) if err != nil { m.markConnectionStatus(cfg, false, err) @@ -200,7 +196,7 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error running := existing.connection != nil && existing.connection.Running() m.setConnectionStatusLocked(existing.config, running, nil) m.mu.Unlock() - _ = conn.Stop(context.Background()) + _ = conn.Stop(connectCtx) return nil } m.connections[cfg.ID] = &connectionEntry{ @@ -216,7 +212,7 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error // Disabled configs are stopped and removed; enabled configs are started or restarted. func (m *Manager) EnsureConnection(ctx context.Context, cfg ChannelConfig) error { if cfg.ID == "" { - return fmt.Errorf("config id is required") + return errors.New("config id is required") } if cfg.Disabled { return m.removeConnection(ctx, cfg.ID) @@ -329,7 +325,7 @@ func (m *Manager) stopAll(ctx context.Context) { func (m *Manager) Stop(ctx context.Context, configID string) error { configID = strings.TrimSpace(configID) if configID == "" { - return fmt.Errorf("config id is required") + return errors.New("config id is required") } m.mu.Lock() entry := m.connections[configID] @@ -349,7 +345,7 @@ func (m *Manager) Stop(ctx context.Context, configID string) error { func (m *Manager) StopByBot(ctx context.Context, botID string) error { botID = strings.TrimSpace(botID) if botID == "" { - return fmt.Errorf("bot id is required") + return errors.New("bot id is required") } m.mu.Lock() defer m.mu.Unlock() diff --git a/internal/channel/helpers_test.go b/internal/channel/helpers_test.go index 46d90b13..2327e490 100644 --- a/internal/channel/helpers_test.go +++ b/internal/channel/helpers_test.go @@ -65,4 +65,3 @@ func TestBindingCriteriaFromIdentity(t *testing.T) { t.Fatalf("unexpected username: %s", criteria.Attribute("username")) } } - diff --git a/internal/channel/identities/service.go b/internal/channel/identities/service.go index e48ab1cb..4a67cbcb 100644 --- a/internal/channel/identities/service.go +++ b/internal/channel/identities/service.go @@ -21,9 +21,7 @@ type Service struct { logger *slog.Logger } -var ( - ErrChannelIdentityNotFound = errors.New("channel identity not found") -) +var ErrChannelIdentityNotFound = errors.New("channel identity not found") // NewService creates a new channel identity service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { @@ -39,12 +37,12 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { // Create creates a new channel identity for the given channel subject. func (s *Service) Create(ctx context.Context, channel, channelSubjectID, displayName string) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } channel = normalizeChannel(channel) channelSubjectID = strings.TrimSpace(channelSubjectID) if channel == "" || channelSubjectID == "" { - return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + return ChannelIdentity{}, errors.New("channel and channel_subject_id are required") } row, err := s.queries.CreateChannelIdentity(ctx, sqlc.CreateChannelIdentityParams{ UserID: pgtype.UUID{}, @@ -63,7 +61,7 @@ func (s *Service) Create(ctx context.Context, channel, channelSubjectID, display // GetByID returns a channel identity by its ID. func (s *Service) GetByID(ctx context.Context, channelIdentityID string) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } pgID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -82,7 +80,7 @@ func (s *Service) GetByID(ctx context.Context, channelIdentityID string) (Channe // Canonicalize validates and returns the same channel identity ID. func (s *Service) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { if s.queries == nil { - return "", fmt.Errorf("channel identity queries not configured") + return "", errors.New("channel identity queries not configured") } pgID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -102,12 +100,12 @@ func (s *Service) Canonicalize(ctx context.Context, channelIdentityID string) (s // Optional meta may contain avatar_url which is stored as a dedicated column. func (s *Service) ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string, meta map[string]any) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } channel = normalizeChannel(channel) channelSubjectID = strings.TrimSpace(channelSubjectID) if channel == "" || channelSubjectID == "" { - return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + return ChannelIdentity{}, errors.New("channel and channel_subject_id are required") } avatarURL := "" @@ -134,7 +132,7 @@ func (s *Service) ResolveByChannelIdentity(ctx context.Context, channel, channel // UpsertChannelIdentity creates or updates a channel identity mapping. func (s *Service) UpsertChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string, metadata map[string]any) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } channel = normalizeChannel(channel) channelSubjectID = strings.TrimSpace(channelSubjectID) @@ -166,7 +164,7 @@ func (s *Service) UpsertChannelIdentity(ctx context.Context, channel, channelSub // ListCanonicalChannelIdentities lists channel identities under the same linked user. func (s *Service) ListCanonicalChannelIdentities(ctx context.Context, channelIdentityID string) ([]ChannelIdentity, error) { if s.queries == nil { - return nil, fmt.Errorf("channel identity queries not configured") + return nil, errors.New("channel identity queries not configured") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -196,7 +194,7 @@ func (s *Service) ListCanonicalChannelIdentities(ctx context.Context, channelIde // ListUserChannelIdentities lists all channel identities linked to a user. func (s *Service) ListUserChannelIdentities(ctx context.Context, userID string) ([]ChannelIdentity, error) { if s.queries == nil { - return nil, fmt.Errorf("channel identity queries not configured") + return nil, errors.New("channel identity queries not configured") } pgUserID, err := db.ParseUUID(userID) if err != nil { @@ -216,7 +214,7 @@ func (s *Service) ListUserChannelIdentities(ctx context.Context, userID string) // GetLinkedUserID returns the linked user ID for a channel identity. func (s *Service) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { if s.queries == nil { - return "", fmt.Errorf("channel identity queries not configured") + return "", errors.New("channel identity queries not configured") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -238,7 +236,7 @@ func (s *Service) GetLinkedUserID(ctx context.Context, channelIdentityID string) // LinkChannelIdentityToUser binds a channel identity to a user. func (s *Service) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { if s.queries == nil { - return fmt.Errorf("channel identity queries not configured") + return errors.New("channel identity queries not configured") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -264,9 +262,7 @@ func (s *Service) LinkChannelIdentityToUser(ctx context.Context, channelIdentity func toChannelIdentity(row sqlc.ChannelIdentity) ChannelIdentity { var metadata map[string]any if len(row.Metadata) > 0 { - if err := json.Unmarshal(row.Metadata, &metadata); err != nil { - slog.Warn("unmarshal channel identity metadata failed", slog.String("id", row.ID.String()), slog.Any("error", err)) - } + _ = json.Unmarshal(row.Metadata, &metadata) } if metadata == nil { metadata = map[string]any{} diff --git a/internal/channel/inbound.go b/internal/channel/inbound.go index 630d7aef..92ba2452 100644 --- a/internal/channel/inbound.go +++ b/internal/channel/inbound.go @@ -2,12 +2,11 @@ package channel import ( "context" - "fmt" + "errors" "log/slog" ) type inboundTask struct { - ctx context.Context cfg ChannelConfig msg InboundMessage } @@ -15,17 +14,13 @@ type inboundTask struct { // HandleInbound enqueues an inbound message for asynchronous processing by the worker pool. func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { if m.processor == nil { - return fmt.Errorf("inbound processor not configured") - } - if ctx == nil { - ctx = context.Background() + return errors.New("inbound processor not configured") } m.startInboundWorkers(ctx) if m.inboundCtx != nil && m.inboundCtx.Err() != nil { - return fmt.Errorf("inbound dispatcher stopped") + return errors.New("inbound dispatcher stopped") } task := inboundTask{ - ctx: context.WithoutCancel(ctx), cfg: cfg, msg: msg, } @@ -33,13 +28,13 @@ func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg Inbo case m.inboundQueue <- task: return nil default: - return fmt.Errorf("inbound queue full") + return errors.New("inbound queue full") } } func (m *Manager) handleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { if m.processor == nil { - return fmt.Errorf("inbound processor not configured") + return errors.New("inbound processor not configured") } sender := m.newReplySender(cfg, msg.Channel) if err := m.processor.HandleInbound(ctx, cfg, msg, sender); err != nil { @@ -53,13 +48,11 @@ func (m *Manager) handleInbound(ctx context.Context, cfg ChannelConfig, msg Inbo func (m *Manager) startInboundWorkers(ctx context.Context) { m.inboundOnce.Do(func() { - workerCtx := ctx - if workerCtx == nil { - workerCtx = context.Background() - } - m.inboundCtx, m.inboundCancel = context.WithCancel(workerCtx) + workerCtx := context.WithoutCancel(ctx) + inboundCtx, inboundCancel := context.WithCancel(workerCtx) + m.inboundCtx, m.inboundCancel = inboundCtx, inboundCancel for i := 0; i < m.inboundWorkers; i++ { - go m.runInboundWorker(m.inboundCtx) + go m.runInboundWorker(inboundCtx) } }) } @@ -70,7 +63,7 @@ func (m *Manager) runInboundWorker(ctx context.Context) { case <-ctx.Done(): return case task := <-m.inboundQueue: - if err := m.handleInbound(task.ctx, task.cfg, task.msg); err != nil { + if err := m.handleInbound(ctx, task.cfg, task.msg); err != nil { if m.logger != nil { m.logger.Error("inbound processing failed", slog.String("channel", task.msg.Channel.String()), slog.Any("error", err)) } diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index a9e98e91..1f66ec60 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -3,6 +3,7 @@ package inbound import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -31,9 +32,7 @@ const ( processingStatusTimeout = 60 * time.Second ) -var ( - whitespacePattern = regexp.MustCompile(`\s+`) -) +var whitespacePattern = regexp.MustCompile(`\s+`) // RouteResolver resolves and manages channel routes. type RouteResolver interface { @@ -137,10 +136,10 @@ func (p *ChannelInboundProcessor) SetInboxService(service *inbox.Service) { // HandleInbound processes an inbound channel message through identity resolution and chat gateway. func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error { if p.runner == nil { - return fmt.Errorf("channel inbound processor not configured") + return errors.New("channel inbound processor not configured") } if sender == nil { - return fmt.Errorf("reply sender not configured") + return errors.New("reply sender not configured") } text := buildInboundQuery(msg.Message, nil) if p.logger != nil { @@ -189,7 +188,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel // Resolve or create the route via channel_routes. if p.routeResolver == nil { - return fmt.Errorf("route resolver not configured") + return errors.New("route resolver not configured") } routeMetadata := buildRouteMetadata(msg, identity) resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{ @@ -306,7 +305,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel } target := strings.TrimSpace(msg.ReplyTarget) if target == "" { - err := fmt.Errorf("reply target missing") + err := errors.New("reply target missing") if statusNotifier != nil { if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) @@ -509,7 +508,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel attachmentsApplied := false for _, output := range outputs { outMessage := buildChannelMessage(output, desc.Capabilities) - if outMessage.IsEmpty() && !(len(outboundAttachments) > 0 && !attachmentsApplied) { + if outMessage.IsEmpty() && (len(outboundAttachments) == 0 || attachmentsApplied) { continue } plainText := strings.TrimSpace(outMessage.PlainText()) @@ -1324,7 +1323,7 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool { // requireIdentity resolves identity for the current message. Always resolves from msg so each sender is identified correctly (no reuse of context state across messages). func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { if p.identity == nil { - return IdentityState{}, fmt.Errorf("identity resolver not configured") + return IdentityState{}, errors.New("identity resolver not configured") } return p.identity.Resolve(ctx, cfg, msg) } @@ -1340,7 +1339,7 @@ func (p *ChannelInboundProcessor) resolveProcessingStatusNotifier(channelType ch return notifier } -func (p *ChannelInboundProcessor) notifyProcessingStarted( +func (*ChannelInboundProcessor) notifyProcessingStarted( ctx context.Context, notifier channel.ProcessingStatusNotifier, cfg channel.ChannelConfig, @@ -1355,7 +1354,7 @@ func (p *ChannelInboundProcessor) notifyProcessingStarted( return notifier.ProcessingStarted(statusCtx, cfg, msg, info) } -func (p *ChannelInboundProcessor) notifyProcessingCompleted( +func (*ChannelInboundProcessor) notifyProcessingCompleted( ctx context.Context, notifier channel.ProcessingStatusNotifier, cfg channel.ChannelConfig, @@ -1371,7 +1370,7 @@ func (p *ChannelInboundProcessor) notifyProcessingCompleted( return notifier.ProcessingCompleted(statusCtx, cfg, msg, info, handle) } -func (p *ChannelInboundProcessor) notifyProcessingFailed( +func (*ChannelInboundProcessor) notifyProcessingFailed( ctx context.Context, notifier channel.ProcessingStatusNotifier, cfg channel.ChannelConfig, @@ -1570,7 +1569,7 @@ func (p *ChannelInboundProcessor) loadInboundAttachmentPayload( } platformKey := strings.TrimSpace(att.PlatformKey) if platformKey == "" { - return inboundAttachmentPayload{}, fmt.Errorf("attachment has no ingestible payload") + return inboundAttachmentPayload{}, errors.New("attachment has no ingestible payload") } resolver := p.resolveAttachmentResolver(msg.Channel) if resolver == nil { @@ -1581,7 +1580,7 @@ func (p *ChannelInboundProcessor) loadInboundAttachmentPayload( return inboundAttachmentPayload{}, fmt.Errorf("resolve attachment by platform key: %w", err) } if resolved.Reader == nil { - return inboundAttachmentPayload{}, fmt.Errorf("resolved attachment reader is nil") + return inboundAttachmentPayload{}, errors.New("resolved attachment reader is nil") } mime := strings.TrimSpace(att.Mime) if mime == "" { @@ -1605,7 +1604,7 @@ func openInboundAttachmentURL(ctx context.Context, rawURL string) (inboundAttach return inboundAttachmentPayload{}, fmt.Errorf("build request: %w", err) } client := &http.Client{Timeout: 20 * time.Second} - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: URL is an attachment URL provided by the inbound channel adapter if err != nil { return inboundAttachmentPayload{}, fmt.Errorf("download attachment: %w", err) } @@ -1775,7 +1774,7 @@ func applyAssetToAttachment(asset media.Asset, botID string, item *channel.Attac // extractStorageKey derives the media storage key from a container-internal // access path. The expected path format is /data/media/. -func extractStorageKey(accessPath string, botID string) string { +func extractStorageKey(accessPath string, _ string) string { marker := filepath.Join("/data", "media") if !strings.HasSuffix(marker, "/") { marker += "/" @@ -1972,8 +1971,7 @@ func buildRouteMetadata(msg channel.InboundMessage, identity InboundIdentity) ma if v == "" { continue } - switch k { - case "username": + if k == "username" { m["sender_username"] = v } } diff --git a/internal/channel/inbound/channel_test.go b/internal/channel/inbound/channel_test.go index 75ccc0fc..dbc2febc 100644 --- a/internal/channel/inbound/channel_test.go +++ b/internal/channel/inbound/channel_test.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/json" "errors" - "fmt" "io" "log/slog" "net/http" @@ -29,7 +28,7 @@ type fakeChatGateway struct { onChat func(conversation.ChatRequest) } -func (f *fakeChatGateway) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) { +func (f *fakeChatGateway) Chat(_ context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) { f.gotReq = req if f.onChat != nil { f.onChat(req) @@ -37,7 +36,7 @@ func (f *fakeChatGateway) Chat(ctx context.Context, req conversation.ChatRequest return f.resp, f.err } -func (f *fakeChatGateway) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) { +func (f *fakeChatGateway) StreamChat(_ context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) { f.gotReq = req if f.onChat != nil { f.onChat(req) @@ -63,7 +62,7 @@ func (f *fakeChatGateway) StreamChat(ctx context.Context, req conversation.ChatR return chunks, errs } -func (f *fakeChatGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { +func (*fakeChatGateway) TriggerSchedule(_ context.Context, _ string, _ schedule.TriggerPayload, _ string) error { return nil } @@ -72,12 +71,12 @@ type fakeReplySender struct { events []channel.StreamEvent } -func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) error { +func (s *fakeReplySender) Send(_ context.Context, msg channel.OutboundMessage) error { s.sent = append(s.sent, msg) return nil } -func (s *fakeReplySender) OpenStream(ctx context.Context, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (s *fakeReplySender) OpenStream(_ context.Context, target string, _ channel.StreamOptions) (channel.OutboundStream, error) { return &fakeOutboundStream{ sender: s, target: strings.TrimSpace(target), @@ -89,7 +88,7 @@ type fakeOutboundStream struct { target string } -func (s *fakeOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { +func (s *fakeOutboundStream) Push(_ context.Context, event channel.StreamEvent) error { if s == nil || s.sender == nil { return nil } @@ -103,7 +102,7 @@ func (s *fakeOutboundStream) Push(ctx context.Context, event channel.StreamEvent return nil } -func (s *fakeOutboundStream) Close(ctx context.Context) error { +func (*fakeOutboundStream) Close(_ context.Context) error { return nil } @@ -119,20 +118,20 @@ type fakeProcessingStatusNotifier struct { failedCause error } -func (n *fakeProcessingStatusNotifier) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (n *fakeProcessingStatusNotifier) ProcessingStarted(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { n.events = append(n.events, "started") n.info = append(n.info, info) return n.startedHandle, n.startedErr } -func (n *fakeProcessingStatusNotifier) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (n *fakeProcessingStatusNotifier) ProcessingCompleted(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { n.events = append(n.events, "completed") n.info = append(n.info, info) n.completedSeen = handle return n.completedErr } -func (n *fakeProcessingStatusNotifier) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (n *fakeProcessingStatusNotifier) ProcessingFailed(_ context.Context, _ channel.ChannelConfig, _ channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { n.events = append(n.events, "failed") n.info = append(n.info, info) n.failedSeen = handle @@ -144,11 +143,11 @@ type fakeProcessingStatusAdapter struct { notifier *fakeProcessingStatusNotifier } -func (a *fakeProcessingStatusAdapter) Type() channel.ChannelType { +func (*fakeProcessingStatusAdapter) Type() channel.ChannelType { return channel.ChannelType("feishu") } -func (a *fakeProcessingStatusAdapter) Descriptor() channel.Descriptor { +func (*fakeProcessingStatusAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: channel.ChannelType("feishu"), Capabilities: channel.ChannelCapabilities{ @@ -188,7 +187,7 @@ type fakeMediaIngestor struct { storageKeyErr error } -func (f *fakeMediaIngestor) Ingest(ctx context.Context, input media.IngestInput) (media.Asset, error) { +func (f *fakeMediaIngestor) Ingest(_ context.Context, input media.IngestInput) (media.Asset, error) { f.calls++ f.inputs = append(f.inputs, input) if input.Reader != nil { @@ -217,21 +216,21 @@ func (f *fakeMediaIngestor) GetByStorageKey(_ context.Context, _, _ string) (med return f.storageKeyAsset, f.storageKeyErr } -func (f *fakeMediaIngestor) IngestContainerFile(_ context.Context, _, _ string) (media.Asset, error) { - return media.Asset{}, fmt.Errorf("not implemented in test") +func (*fakeMediaIngestor) IngestContainerFile(_ context.Context, _, _ string) (media.Asset, error) { + return media.Asset{}, errors.New("not implemented in test") } -func (f *fakeMediaIngestor) AccessPath(asset media.Asset) string { +func (*fakeMediaIngestor) AccessPath(asset media.Asset) string { return "/data/media/" + asset.StorageKey } type fakeAttachmentResolverAdapter struct{} -func (a *fakeAttachmentResolverAdapter) Type() channel.ChannelType { +func (*fakeAttachmentResolverAdapter) Type() channel.ChannelType { return channel.ChannelType("resolver-test") } -func (a *fakeAttachmentResolverAdapter) Descriptor() channel.Descriptor { +func (*fakeAttachmentResolverAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: channel.ChannelType("resolver-test"), DisplayName: "ResolverTest", @@ -242,7 +241,7 @@ func (a *fakeAttachmentResolverAdapter) Descriptor() channel.Descriptor { } } -func (a *fakeAttachmentResolverAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) { +func (*fakeAttachmentResolverAdapter) ResolveAttachment(_ context.Context, _ channel.ChannelConfig, _ channel.Attachment) (channel.AttachmentPayload, error) { return channel.AttachmentPayload{ Reader: io.NopCloser(strings.NewReader("resolver-bytes")), Mime: "application/octet-stream", @@ -251,14 +250,14 @@ func (a *fakeAttachmentResolverAdapter) ResolveAttachment(ctx context.Context, c }, nil } -func (f *fakeChatService) ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error) { +func (f *fakeChatService) ResolveConversation(_ context.Context, _ route.ResolveInput) (route.ResolveConversationResult, error) { if f.resolveErr != nil { return route.ResolveConversationResult{}, f.resolveErr } return f.resolveResult, nil } -func (f *fakeChatService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { +func (f *fakeChatService) Persist(_ context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { f.persistedIn = append(f.persistedIn, input) msg := messagepkg.Message{ BotID: input.BotID, @@ -542,14 +541,8 @@ func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) { if len(sender.sent) != 0 { t.Fatalf("group passive sync should not send reply: %+v", sender.sent) } - if len(chatSvc.persisted) != 1 { - t.Fatalf("expected 1 passive persisted message, got: %d", len(chatSvc.persisted)) - } - if chatSvc.persisted[0].Role != "user" { - t.Fatalf("expected persisted role user, got: %s", chatSvc.persisted[0].Role) - } - if chatSvc.persisted[0].BotID != "bot-1" { - t.Fatalf("expected passive persisted bot_id bot-1, got: %s", chatSvc.persisted[0].BotID) + if len(chatSvc.persisted) != 0 { + t.Fatalf("group passive sync should not persist to messages directly, got: %d", len(chatSvc.persisted)) } } @@ -621,11 +614,11 @@ func TestChannelInboundProcessorPersistsAttachmentAssetRefs(t *testing.T) { Text: "attachment test", Attachments: []channel.Attachment{ { - Type: channel.AttachmentImage, - URL: "https://example.com/img.png", + Type: channel.AttachmentImage, + URL: "https://example.com/img.png", ContentHash: "asset-1", - Name: "img.png", - Mime: "image/png", + Name: "img.png", + Mime: "image/png", }, }, }, @@ -893,7 +886,7 @@ func TestChannelInboundProcessorProcessingStatusSuccessLifecycle(t *testing.T) { {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, - onChat: func(req conversation.ChatRequest) { + onChat: func(_ conversation.ChatRequest) { if len(notifier.events) != 1 || notifier.events[0] != "started" { t.Fatalf("expected started before chat call, got events: %+v", notifier.events) } @@ -1065,7 +1058,7 @@ func TestChannelInboundProcessorProcessingFailedNotifyErrorDoesNotOverrideChatEr func TestDownloadInboundAttachmentURLTooLarge(t *testing.T) { t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Length", "999999999") _, _ = w.Write([]byte("x")) @@ -1464,11 +1457,11 @@ func TestMapChannelToChatAttachments(t *testing.T) { attachments := []channel.Attachment{ { - Type: channel.AttachmentImage, + Type: channel.AttachmentImage, ContentHash: "asset-1", - URL: "/data/media/ab/c.png", - Base64: "AAAA", - Mime: "image/png", + URL: "/data/media/ab/c.png", + Base64: "AAAA", + Mime: "image/png", }, { Type: channel.AttachmentFile, diff --git a/internal/channel/inbound/identity.go b/internal/channel/inbound/identity.go index e010bb9b..3ab94224 100644 --- a/internal/channel/inbound/identity.go +++ b/internal/channel/inbound/identity.go @@ -158,7 +158,7 @@ func (r *IdentityResolver) Middleware() channel.Middleware { // 2. Authorization: bot membership check with guest/preauth fallback func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { if r.channelIdentities == nil { - return IdentityState{}, fmt.Errorf("identity resolver not configured") + return IdentityState{}, errors.New("identity resolver not configured") } botID := strings.TrimSpace(msg.BotID) @@ -183,7 +183,7 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi // Phase 1: Global identity resolution (unconditional). if subjectID == "" { - return state, fmt.Errorf("cannot resolve identity: no channel_subject_id") + return state, errors.New("cannot resolve identity: no channel_subject_id") } channelIdentityID, linkedUserID, err := r.resolveIdentityWithLinkedUser(ctx, msg, subjectID, displayName, avatarURL) @@ -289,7 +289,7 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi func (r *IdentityResolver) resolveIdentityWithLinkedUser(ctx context.Context, msg channel.InboundMessage, primarySubjectID, displayName, avatarURL string) (string, string, error) { candidates := identitySubjectCandidates(msg, primarySubjectID) if len(candidates) == 0 { - return "", "", fmt.Errorf("cannot resolve identity: no channel_subject_id") + return "", "", errors.New("cannot resolve identity: no channel_subject_id") } var meta map[string]any @@ -504,9 +504,6 @@ func (r *IdentityResolver) resolveProfileFromDirectory(ctx context.Context, cfg if !ok || directoryAdapter == nil { return "", "" } - if ctx == nil { - ctx = context.Background() - } lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() entry, err := directoryAdapter.ResolveEntry(lookupCtx, cfg, subjectID, channel.DirectoryEntryUser) diff --git a/internal/channel/inbound/identity_test.go b/internal/channel/inbound/identity_test.go index 6e2a3eeb..76604a4f 100644 --- a/internal/channel/inbound/identity_test.go +++ b/internal/channel/inbound/identity_test.go @@ -24,7 +24,7 @@ type fakeChannelIdentityService struct { lastMeta map[string]any } -func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Context, platform, externalID, displayName string, meta map[string]any) (identities.ChannelIdentity, error) { +func (f *fakeChannelIdentityService) ResolveByChannelIdentity(_ context.Context, _, externalID, displayName string, meta map[string]any) (identities.ChannelIdentity, error) { f.calls++ f.lastDisplayName = displayName f.lastMeta = meta @@ -40,7 +40,7 @@ func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Contex return f.channelIdentity, nil } -func (f *fakeChannelIdentityService) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { +func (f *fakeChannelIdentityService) Canonicalize(_ context.Context, channelIdentityID string) (string, error) { if f.canonical != nil { if value, ok := f.canonical[channelIdentityID]; ok { return value, nil @@ -49,7 +49,7 @@ func (f *fakeChannelIdentityService) Canonicalize(ctx context.Context, channelId return channelIdentityID, nil } -func (f *fakeChannelIdentityService) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { +func (f *fakeChannelIdentityService) GetLinkedUserID(_ context.Context, channelIdentityID string) (string, error) { if f.linked != nil { if value, ok := f.linked[channelIdentityID]; ok { return value, nil @@ -60,7 +60,7 @@ func (f *fakeChannelIdentityService) GetLinkedUserID(ctx context.Context, channe return channelIdentityID, nil } -func (f *fakeChannelIdentityService) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { +func (f *fakeChannelIdentityService) LinkChannelIdentityToUser(_ context.Context, channelIdentityID, userID string) error { if f.linked == nil { f.linked = map[string]string{} } @@ -73,11 +73,11 @@ type fakeMemberService struct { upsertCalled bool } -func (f *fakeMemberService) IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) { +func (f *fakeMemberService) IsMember(_ context.Context, _, _ string) (bool, error) { return f.isMember, nil } -func (f *fakeMemberService) UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error { +func (f *fakeMemberService) UpsertMemberSimple(_ context.Context, _, _, _ string) error { f.upsertCalled = true return nil } @@ -89,21 +89,21 @@ type fakePolicyService struct { err error } -func (f *fakePolicyService) AllowGuest(ctx context.Context, botID string) (bool, error) { +func (f *fakePolicyService) AllowGuest(_ context.Context, _ string) (bool, error) { if f.err != nil { return false, f.err } return f.allow, nil } -func (f *fakePolicyService) BotType(ctx context.Context, botID string) (string, error) { +func (f *fakePolicyService) BotType(_ context.Context, _ string) (string, error) { if f.err != nil { return "", f.err } return f.botType, nil } -func (f *fakePolicyService) BotOwnerUserID(ctx context.Context, botID string) (string, error) { +func (f *fakePolicyService) BotOwnerUserID(_ context.Context, _ string) (string, error) { if f.err != nil { return "", f.err } @@ -116,7 +116,7 @@ type fakePreauthServiceIdentity struct { markUsed bool } -func (f *fakePreauthServiceIdentity) Get(ctx context.Context, token string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) Get(_ context.Context, token string) (preauth.Key, error) { if f.err != nil { return preauth.Key{}, f.err } @@ -126,7 +126,7 @@ func (f *fakePreauthServiceIdentity) Get(ctx context.Context, token string) (pre return f.key, nil } -func (f *fakePreauthServiceIdentity) MarkUsed(ctx context.Context, id string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) MarkUsed(_ context.Context, _ string) (preauth.Key, error) { f.markUsed = true return f.key, nil } @@ -139,7 +139,7 @@ type fakeBindService struct { onConsume func(channelChannelIdentityID string) } -func (f *fakeBindService) Get(ctx context.Context, token string) (bind.Code, error) { +func (f *fakeBindService) Get(_ context.Context, token string) (bind.Code, error) { if f.getErr != nil { return bind.Code{}, f.getErr } @@ -149,7 +149,7 @@ func (f *fakeBindService) Get(ctx context.Context, token string) (bind.Code, err return f.code, nil } -func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelChannelIdentityID string) error { +func (f *fakeBindService) Consume(_ context.Context, _ bind.Code, channelChannelIdentityID string) error { f.consumeCalled = true if f.onConsume != nil { f.onConsume(channelChannelIdentityID) @@ -168,21 +168,21 @@ func (f *fakeDirectoryAdapter) Type() channel.ChannelType { func (f *fakeDirectoryAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ - Type: f.channelType, - DisplayName: "FakeDirectory", + Type: f.channelType, + DisplayName: "FakeDirectory", Capabilities: channel.ChannelCapabilities{}, } } -func (f *fakeDirectoryAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*fakeDirectoryAdapter) ListPeers(_ context.Context, _ channel.ChannelConfig, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (f *fakeDirectoryAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*fakeDirectoryAdapter) ListGroups(_ context.Context, _ channel.ChannelConfig, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (f *fakeDirectoryAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*fakeDirectoryAdapter) ListGroupMembers(_ context.Context, _ channel.ChannelConfig, _ string, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } @@ -225,7 +225,7 @@ func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) { registry := channel.NewRegistry() directoryAdapter := &fakeDirectoryAdapter{ channelType: channel.ChannelType("feishu"), - resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + resolveFn: func(_ context.Context, _ channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { if kind != channel.DirectoryEntryUser { t.Fatalf("expected kind user, got %s", kind) } @@ -275,7 +275,7 @@ func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testin registry := channel.NewRegistry() directoryAdapter := &fakeDirectoryAdapter{ channelType: channel.ChannelType("feishu"), - resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + resolveFn: func(_ context.Context, _ channel.ChannelConfig, _ string, _ channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { return channel.DirectoryEntry{}, errors.New("lookup failed") }, } @@ -317,7 +317,7 @@ func TestIdentityResolverDirectoryAvatarURLPropagated(t *testing.T) { registry := channel.NewRegistry() directoryAdapter := &fakeDirectoryAdapter{ channelType: channel.ChannelType("feishu"), - resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + resolveFn: func(_ context.Context, _ channel.ChannelConfig, _ string, _ channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { return channel.DirectoryEntry{ Kind: channel.DirectoryEntryUser, Name: "Avatar User", diff --git a/internal/channel/inbound_test.go b/internal/channel/inbound_test.go index 2135d5ba..819d662c 100644 --- a/internal/channel/inbound_test.go +++ b/internal/channel/inbound_test.go @@ -2,7 +2,7 @@ package channel import ( "context" - "fmt" + "errors" "log/slog" "testing" ) @@ -13,8 +13,8 @@ type mockAdapter struct { streamEvents []StreamEvent } -func (m *mockAdapter) Type() ChannelType { return ChannelType("test") } -func (m *mockAdapter) Descriptor() Descriptor { +func (*mockAdapter) Type() ChannelType { return ChannelType("test") } +func (*mockAdapter) Descriptor() Descriptor { return Descriptor{ Type: ChannelType("test"), DisplayName: "Test", @@ -25,12 +25,13 @@ func (m *mockAdapter) Descriptor() Descriptor { }, } } -func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { + +func (m *mockAdapter) Send(_ context.Context, _ ChannelConfig, msg OutboundMessage) error { m.sentMessages = append(m.sentMessages, msg) return nil } -func (m *mockAdapter) OpenStream(ctx context.Context, cfg ChannelConfig, target string, opts StreamOptions) (OutboundStream, error) { +func (m *mockAdapter) OpenStream(_ context.Context, _ ChannelConfig, _ string, _ StreamOptions) (OutboundStream, error) { return &mockAdapterStream{adapter: m}, nil } @@ -38,7 +39,7 @@ type mockAdapterStream struct { adapter *mockAdapter } -func (s *mockAdapterStream) Push(ctx context.Context, event StreamEvent) error { +func (s *mockAdapterStream) Push(_ context.Context, event StreamEvent) error { if s == nil || s.adapter == nil { return nil } @@ -52,7 +53,7 @@ func (s *mockAdapterStream) Push(ctx context.Context, event StreamEvent) error { return nil } -func (s *mockAdapterStream) Close(ctx context.Context) error { +func (*mockAdapterStream) Close(_ context.Context) error { return nil } @@ -73,14 +74,14 @@ func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelCon return nil } if sender == nil { - return fmt.Errorf("sender missing") + return errors.New("sender missing") } return sender.Send(ctx, *f.resp) } type fakeInboundStreamProcessor struct{} -func (f *fakeInboundStreamProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { +func (*fakeInboundStreamProcessor) HandleInbound(ctx context.Context, _ ChannelConfig, _ InboundMessage, sender StreamReplySender) error { stream, err := sender.OpenStream(ctx, "stream-target", StreamOptions{}) if err != nil { return err diff --git a/internal/channel/lifecycle.go b/internal/channel/lifecycle.go index dcea4fb8..7624d05d 100644 --- a/internal/channel/lifecycle.go +++ b/internal/channel/lifecycle.go @@ -44,14 +44,14 @@ func NewLifecycle(store LifecycleStore, controller ConnectionController) *Lifecy // For disabled=false, it stores config then starts connection; on start failure it rolls back. func (s *Lifecycle) UpsertBotChannelConfig(ctx context.Context, botID string, channelType ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { if s.store == nil { - return ChannelConfig{}, fmt.Errorf("channel lifecycle store not configured") + return ChannelConfig{}, errors.New("channel lifecycle store not configured") } disabled := false if req.Disabled != nil { disabled = *req.Disabled } if !disabled && s.controller == nil { - return ChannelConfig{}, fmt.Errorf("channel connection controller not configured") + return ChannelConfig{}, errors.New("channel connection controller not configured") } previous, hadPrevious, err := s.getPreviousConfig(ctx, botID, channelType) @@ -73,7 +73,7 @@ func (s *Lifecycle) UpsertBotChannelConfig(ctx context.Context, botID string, ch if err := s.controller.EnsureConnection(ctx, updated); err != nil { if rollbackErr := s.rollbackUpsert(ctx, botID, channelType, hadPrevious, previous); rollbackErr != nil { - return ChannelConfig{}, fmt.Errorf("%w (rollback failed: %v): %w", ErrEnableChannelFailed, rollbackErr, err) + return ChannelConfig{}, fmt.Errorf("%w (rollback failed: %w): %w", ErrEnableChannelFailed, rollbackErr, err) } return ChannelConfig{}, fmt.Errorf("%w: %w", ErrEnableChannelFailed, err) } @@ -83,7 +83,7 @@ func (s *Lifecycle) UpsertBotChannelConfig(ctx context.Context, botID string, ch // DeleteBotChannelConfig removes persisted config and stops active runtime connection. func (s *Lifecycle) DeleteBotChannelConfig(ctx context.Context, botID string, channelType ChannelType) error { if s.store == nil { - return fmt.Errorf("channel lifecycle store not configured") + return errors.New("channel lifecycle store not configured") } if err := s.store.DeleteConfig(ctx, botID, channelType); err != nil { return err @@ -97,10 +97,10 @@ func (s *Lifecycle) DeleteBotChannelConfig(ctx context.Context, botID string, ch // SetBotChannelStatus updates only the disabled status and applies runtime lifecycle. func (s *Lifecycle) SetBotChannelStatus(ctx context.Context, botID string, channelType ChannelType, disabled bool) (ChannelConfig, error) { if s.store == nil { - return ChannelConfig{}, fmt.Errorf("channel lifecycle store not configured") + return ChannelConfig{}, errors.New("channel lifecycle store not configured") } if s.controller == nil { - return ChannelConfig{}, fmt.Errorf("channel connection controller not configured") + return ChannelConfig{}, errors.New("channel connection controller not configured") } updated, err := s.store.UpdateConfigDisabled(ctx, botID, channelType, disabled) @@ -114,7 +114,7 @@ func (s *Lifecycle) SetBotChannelStatus(ctx context.Context, botID string, chann if err := s.controller.EnsureConnection(ctx, updated); err != nil { if _, rollbackErr := s.store.UpdateConfigDisabled(ctx, botID, channelType, true); rollbackErr != nil { - return ChannelConfig{}, fmt.Errorf("%w (status rollback failed: %v): %w", ErrEnableChannelFailed, rollbackErr, err) + return ChannelConfig{}, fmt.Errorf("%w (status rollback failed: %w): %w", ErrEnableChannelFailed, rollbackErr, err) } s.controller.RemoveConnection(ctx, botID, channelType) return ChannelConfig{}, fmt.Errorf("%w: %w", ErrEnableChannelFailed, err) diff --git a/internal/channel/lifecycle_test.go b/internal/channel/lifecycle_test.go index 9c5bb096..366296b7 100644 --- a/internal/channel/lifecycle_test.go +++ b/internal/channel/lifecycle_test.go @@ -65,12 +65,12 @@ func TestLifecycleUpsertDisabledRemovesConnection(t *testing.T) { removeCalled := false store := &fakeLifecycleStore{ - upsertFunc: func(ctx context.Context, botID string, channelType ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { + upsertFunc: func(_ context.Context, botID string, channelType ChannelType, _ UpsertConfigRequest) (ChannelConfig, error) { return ChannelConfig{ID: "cfg-1", BotID: botID, ChannelType: channelType, Disabled: true}, nil }, } controller := &fakeConnectionController{ - removeFunc: func(ctx context.Context, botID string, channelType ChannelType) { + removeFunc: func(_ context.Context, _ string, _ ChannelType) { removeCalled = true }, } @@ -112,10 +112,10 @@ func TestLifecycleUpsertEnableFailureRollsBackToPrevious(t *testing.T) { upsertCalls := 0 ensureCalls := 0 store := &fakeLifecycleStore{ - resolveFunc: func(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { + resolveFunc: func(_ context.Context, _ string, _ ChannelType) (ChannelConfig, error) { return previous, nil }, - upsertFunc: func(ctx context.Context, botID string, channelType ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { + upsertFunc: func(_ context.Context, _ string, _ ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { upsertCalls++ if upsertCalls == 1 { return newConfig, nil @@ -127,7 +127,7 @@ func TestLifecycleUpsertEnableFailureRollsBackToPrevious(t *testing.T) { }, } controller := &fakeConnectionController{ - ensureFunc: func(ctx context.Context, cfg ChannelConfig) error { + ensureFunc: func(_ context.Context, _ ChannelConfig) error { ensureCalls++ if ensureCalls == 1 { return errors.New("dial failed") @@ -158,10 +158,10 @@ func TestLifecycleUpsertEnableFailureWithoutPreviousDeletesNewConfig(t *testing. deleteCalls := 0 store := &fakeLifecycleStore{ - resolveFunc: func(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { + resolveFunc: func(_ context.Context, _ string, _ ChannelType) (ChannelConfig, error) { return ChannelConfig{}, ErrChannelConfigNotFound }, - upsertFunc: func(ctx context.Context, botID string, channelType ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { + upsertFunc: func(_ context.Context, botID string, channelType ChannelType, _ UpsertConfigRequest) (ChannelConfig, error) { return ChannelConfig{ ID: "cfg-new", BotID: botID, @@ -169,13 +169,13 @@ func TestLifecycleUpsertEnableFailureWithoutPreviousDeletesNewConfig(t *testing. Credentials: map[string]any{"botToken": "new"}, }, nil }, - deleteFunc: func(ctx context.Context, botID string, channelType ChannelType) error { + deleteFunc: func(_ context.Context, _ string, _ ChannelType) error { deleteCalls++ return nil }, } controller := &fakeConnectionController{ - ensureFunc: func(ctx context.Context, cfg ChannelConfig) error { + ensureFunc: func(_ context.Context, _ ChannelConfig) error { return errors.New("start failed") }, } @@ -199,12 +199,12 @@ func TestLifecycleDeleteStopsConnection(t *testing.T) { removeCalled := false store := &fakeLifecycleStore{ - deleteFunc: func(ctx context.Context, botID string, channelType ChannelType) error { + deleteFunc: func(_ context.Context, _ string, _ ChannelType) error { return nil }, } controller := &fakeConnectionController{ - removeFunc: func(ctx context.Context, botID string, channelType ChannelType) { + removeFunc: func(_ context.Context, _ string, _ ChannelType) { removeCalled = true }, } @@ -223,7 +223,7 @@ func TestLifecycleSetBotChannelStatusDisable(t *testing.T) { removeCalled := false store := &fakeLifecycleStore{ - statusFunc: func(ctx context.Context, botID string, channelType ChannelType, disabled bool) (ChannelConfig, error) { + statusFunc: func(_ context.Context, botID string, channelType ChannelType, disabled bool) (ChannelConfig, error) { if !disabled { t.Fatalf("expected disabled=true update") } @@ -231,7 +231,7 @@ func TestLifecycleSetBotChannelStatusDisable(t *testing.T) { }, } controller := &fakeConnectionController{ - removeFunc: func(ctx context.Context, botID string, channelType ChannelType) { + removeFunc: func(_ context.Context, _ string, _ ChannelType) { removeCalled = true }, } @@ -255,7 +255,7 @@ func TestLifecycleSetBotChannelStatusEnableFailureRollsBack(t *testing.T) { statusCalls := 0 removeCalled := false store := &fakeLifecycleStore{ - statusFunc: func(ctx context.Context, botID string, channelType ChannelType, disabled bool) (ChannelConfig, error) { + statusFunc: func(_ context.Context, botID string, channelType ChannelType, disabled bool) (ChannelConfig, error) { statusCalls++ if statusCalls == 1 && disabled { t.Fatalf("first status update should enable config") @@ -272,10 +272,10 @@ func TestLifecycleSetBotChannelStatusEnableFailureRollsBack(t *testing.T) { }, } controller := &fakeConnectionController{ - ensureFunc: func(ctx context.Context, cfg ChannelConfig) error { + ensureFunc: func(_ context.Context, _ ChannelConfig) error { return errors.New("start failed") }, - removeFunc: func(ctx context.Context, botID string, channelType ChannelType) { + removeFunc: func(_ context.Context, _ string, _ ChannelType) { removeCalled = true }, } diff --git a/internal/channel/manager.go b/internal/channel/manager.go index cf36afce..ad510875 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -135,9 +135,6 @@ func (m *Manager) AddAdapter(ctx context.Context, adapter Adapter) { // RemoveAdapter unregisters an adapter and stops all its active connections. func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { - if ctx == nil { - ctx = context.Background() - } m.mu.Lock() for id, entry := range m.connections { if entry != nil && entry.config.ChannelType == channelType { @@ -190,7 +187,7 @@ func (m *Manager) Start(ctx context.Context) { // Send delivers an outbound message to the specified channel, resolving target and config automatically. func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelType, req SendRequest) error { if m.service == nil { - return fmt.Errorf("channel manager not configured") + return errors.New("channel manager not configured") } sender, ok := m.registry.GetSender(channelType) if !ok { @@ -204,14 +201,14 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp if target == "" { targetChannelIdentityID := strings.TrimSpace(req.ChannelIdentityID) if targetChannelIdentityID == "" { - return fmt.Errorf("target or channel_identity_id is required") + return errors.New("target or channel_identity_id is required") } userCfg, err := m.service.GetChannelIdentityConfig(ctx, targetChannelIdentityID, channelType) if err != nil { if m.logger != nil { m.logger.Warn("channel binding missing", slog.String("channel", channelType.String()), slog.String("channel_identity_id", targetChannelIdentityID)) } - return fmt.Errorf("channel binding required") + return errors.New("channel binding required") } target, err = m.registry.ResolveTargetFromUserConfig(channelType, userCfg.Config) if err != nil { @@ -222,7 +219,7 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp target = normalized } if req.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } if m.logger != nil { m.logger.Info("send outbound", slog.String("channel", channelType.String()), slog.String("bot_id", botID)) @@ -249,7 +246,7 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp // React adds or removes an emoji reaction on a channel message. func (m *Manager) React(ctx context.Context, botID string, channelType ChannelType, req ReactRequest) error { if m.service == nil { - return fmt.Errorf("channel manager not configured") + return errors.New("channel manager not configured") } reactor, ok := m.registry.GetReactor(channelType) if !ok { @@ -261,18 +258,18 @@ func (m *Manager) React(ctx context.Context, botID string, channelType ChannelTy } target := strings.TrimSpace(req.Target) if target == "" { - return fmt.Errorf("target is required for reactions") + return errors.New("target is required for reactions") } if normalized, ok := m.registry.NormalizeTarget(channelType, target); ok { target = normalized } messageID := strings.TrimSpace(req.MessageID) if messageID == "" { - return fmt.Errorf("message_id is required for reactions") + return errors.New("message_id is required for reactions") } emoji := strings.TrimSpace(req.Emoji) if !req.Remove && emoji == "" { - return fmt.Errorf("emoji is required when adding a reaction") + return errors.New("emoji is required when adding a reaction") } if m.logger != nil { m.logger.Info("react outbound", diff --git a/internal/channel/manager_integration_test.go b/internal/channel/manager_integration_test.go index fd45bb3e..13651047 100644 --- a/internal/channel/manager_integration_test.go +++ b/internal/channel/manager_integration_test.go @@ -3,8 +3,6 @@ package channel import ( "context" "errors" - "fmt" - "io" "log/slog" "strings" "sync" @@ -19,31 +17,31 @@ type fakeConfigStore struct { boundChannelIdentityID string } -func (f *fakeConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { +func (f *fakeConfigStore) ResolveEffectiveConfig(_ context.Context, _ string, _ ChannelType) (ChannelConfig, error) { return f.effectiveConfig, nil } -func (f *fakeConfigStore) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { +func (f *fakeConfigStore) GetChannelIdentityConfig(_ context.Context, _ string, _ ChannelType) (ChannelIdentityBinding, error) { if f.channelIdentityConfig.ID == "" && len(f.channelIdentityConfig.Config) == 0 { - return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") + return ChannelIdentityBinding{}, errors.New("channel user config not found") } return f.channelIdentityConfig, nil } -func (f *fakeConfigStore) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { +func (f *fakeConfigStore) UpsertChannelIdentityConfig(_ context.Context, _ string, _ ChannelType, _ UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { return f.channelIdentityConfig, nil } -func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { +func (f *fakeConfigStore) ListConfigsByType(_ context.Context, channelType ChannelType) ([]ChannelConfig, error) { if f.configsByType == nil { return nil, nil } return f.configsByType[channelType], nil } -func (f *fakeConfigStore) ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { +func (f *fakeConfigStore) ResolveChannelIdentityBinding(_ context.Context, _ ChannelType, _ BindingCriteria) (string, error) { if f.boundChannelIdentityID == "" { - return "", fmt.Errorf("channel user binding not found") + return "", errors.New("channel user binding not found") } return f.boundChannelIdentityID, nil } @@ -65,7 +63,7 @@ func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg return nil } if sender == nil { - return fmt.Errorf("sender missing") + return errors.New("sender missing") } return sender.Send(ctx, *f.resp) } @@ -88,17 +86,17 @@ func (f *fakeAdapter) Descriptor() Descriptor { return Descriptor{Type: f.channelType, DisplayName: "Fake", Capabilities: ChannelCapabilities{Text: true}} } -func (f *fakeAdapter) ResolveTarget(channelIdentityConfig map[string]any) (string, error) { +func (*fakeAdapter) ResolveTarget(channelIdentityConfig map[string]any) (string, error) { value := strings.TrimSpace(ReadString(channelIdentityConfig, "target")) if value == "" { - return "", fmt.Errorf("missing target") + return "", errors.New("missing target") } return "resolved:" + value, nil } -func (f *fakeAdapter) NormalizeTarget(raw string) string { return strings.TrimSpace(raw) } +func (*fakeAdapter) NormalizeTarget(raw string) string { return strings.TrimSpace(raw) } -func (f *fakeAdapter) Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) { +func (f *fakeAdapter) Connect(ctx context.Context, cfg ChannelConfig, _ InboundHandler) (Connection, error) { if f.connectErr != nil { return nil, f.connectErr } @@ -115,7 +113,7 @@ func (f *fakeAdapter) Connect(ctx context.Context, cfg ChannelConfig, handler In return NewConnection(cfg, stop), nil } -func (f *fakeAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { +func (f *fakeAdapter) Send(_ context.Context, _ ChannelConfig, msg OutboundMessage) error { f.mu.Lock() f.sent = append(f.sent, msg) f.mu.Unlock() @@ -125,7 +123,7 @@ func (f *fakeAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundM func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{} processor := &fakeInboundProcessorIntegration{ resp: &OutboundMessage{ @@ -178,7 +176,7 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { func TestManagerSendUsesBinding(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{ effectiveConfig: ChannelConfig{ ID: "cfg-1", @@ -220,7 +218,7 @@ func TestManagerSendUsesBinding(t *testing.T) { func TestManagerReconcileStartsAndStops(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{} reg := NewRegistry() adapter := &fakeAdapter{channelType: ChannelType("test")} @@ -261,7 +259,7 @@ func TestManagerReconcileStartsAndStops(t *testing.T) { func TestManagerConnectionStatusesByBotTracksConnectFailure(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{} reg := NewRegistry() adapter := &fakeAdapter{ @@ -295,7 +293,7 @@ func TestManagerConnectionStatusesByBotTracksConnectFailure(t *testing.T) { func TestManagerEnsureConnectionDetachesRequestContext(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{} reg := NewRegistry() adapter := &fakeAdapter{channelType: ChannelType("test")} diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go index 7dff33a3..c92e4b0d 100644 --- a/internal/channel/outbound.go +++ b/internal/channel/outbound.go @@ -2,6 +2,7 @@ package channel import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -195,7 +196,7 @@ func (m *Manager) resolveOutboundPolicy(channelType ChannelType) OutboundPolicy // buildOutboundMessages splits an outbound message into multiple messages based on the policy. func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]OutboundMessage, error) { if msg.Message.IsEmpty() { - return nil, fmt.Errorf("message is required") + return nil, errors.New("message is required") } normalized := normalizeOutboundMessage(msg.Message) chunker := policy.Chunker @@ -250,7 +251,7 @@ func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]Outbou } if len(textMessages) == 0 && len(attachmentMessages) == 0 { - return nil, fmt.Errorf("message is required") + return nil, errors.New("message is required") } if policy.MediaOrder == OutboundOrderTextFirst { return append(textMessages, attachmentMessages...), nil @@ -277,37 +278,37 @@ func validateMessageCapabilities(registry *Registry, channelType ChannelType, ms switch msg.Format { case MessageFormatPlain: if !caps.Text { - return fmt.Errorf("channel does not support plain text") + return errors.New("channel does not support plain text") } case MessageFormatMarkdown: if !caps.Markdown && !caps.RichText { - return fmt.Errorf("channel does not support markdown") + return errors.New("channel does not support markdown") } case MessageFormatRich: if !caps.RichText { - return fmt.Errorf("channel does not support rich text") + return errors.New("channel does not support rich text") } } if len(msg.Parts) > 0 && !caps.RichText { - return fmt.Errorf("channel does not support rich text") + return errors.New("channel does not support rich text") } if len(msg.Attachments) > 0 && !caps.Attachments { - return fmt.Errorf("channel does not support attachments") + return errors.New("channel does not support attachments") } if len(msg.Attachments) > 0 && requiresMedia(msg.Attachments) && !caps.Media { - return fmt.Errorf("channel does not support media") + return errors.New("channel does not support media") } if len(msg.Actions) > 0 && !caps.Buttons { - return fmt.Errorf("channel does not support actions") + return errors.New("channel does not support actions") } if msg.Thread != nil && !caps.Threads { - return fmt.Errorf("channel does not support threads") + return errors.New("channel does not support threads") } if msg.Reply != nil && !caps.Reply { - return fmt.Errorf("channel does not support reply") + return errors.New("channel does not support reply") } if strings.TrimSpace(msg.ID) != "" && !caps.Edit { - return fmt.Errorf("channel does not support edit") + return errors.New("channel does not support edit") } return nil } @@ -318,10 +319,10 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel } target := strings.TrimSpace(msg.Target) if target == "" { - return fmt.Errorf("target is required") + return errors.New("target is required") } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } normalized := msg attachments, err := normalizeAttachmentRefs(msg.Message.Attachments, cfg.ChannelType) @@ -335,7 +336,7 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel editor, _ := m.registry.GetMessageEditor(cfg.ChannelType) if strings.TrimSpace(normalized.Message.ID) != "" { if editor == nil { - return fmt.Errorf("channel does not support edit") + return errors.New("channel does not support edit") } var lastErr error for i := 0; i < policy.RetryMax; i++ { @@ -388,7 +389,7 @@ func normalizeAttachmentRefs(attachments []Attachment, defaultPlatform ChannelTy item.SourcePlatform = defaultPlatform.String() } if item.URL == "" && item.PlatformKey == "" && item.ContentHash == "" && item.Base64 == "" { - return nil, fmt.Errorf("attachment reference is required") + return nil, errors.New("attachment reference is required") } normalized = append(normalized, item) } @@ -412,26 +413,26 @@ func validateStreamEvent(registry *Registry, channelType ChannelType, event Stre switch event.Type { case StreamEventStatus: if event.Status == "" { - return fmt.Errorf("stream status is required") + return errors.New("stream status is required") } case StreamEventDelta: if !caps.Streaming && !caps.BlockStreaming { - return fmt.Errorf("channel does not support streaming") + return errors.New("channel does not support streaming") } case StreamEventPhaseStart, StreamEventPhaseEnd: if !caps.Streaming && !caps.BlockStreaming { - return fmt.Errorf("channel does not support streaming") + return errors.New("channel does not support streaming") } case StreamEventToolCallStart, StreamEventToolCallEnd: if !caps.Streaming && !caps.BlockStreaming { - return fmt.Errorf("channel does not support streaming") + return errors.New("channel does not support streaming") } if event.ToolCall == nil { - return fmt.Errorf("stream tool call payload is required") + return errors.New("stream tool call payload is required") } case StreamEventAttachment: if len(event.Attachments) == 0 { - return fmt.Errorf("stream attachments are required") + return errors.New("stream attachments are required") } if _, err := normalizeAttachmentRefs(event.Attachments, channelType); err != nil { return err @@ -440,11 +441,11 @@ func validateStreamEvent(registry *Registry, channelType ChannelType, event Stre return nil case StreamEventProcessingFailed: if strings.TrimSpace(event.Error) == "" { - return fmt.Errorf("processing failure error is required") + return errors.New("processing failure error is required") } case StreamEventFinal: if event.Final == nil { - return fmt.Errorf("stream final payload is required") + return errors.New("stream final payload is required") } if err := validateMessageCapabilities(registry, channelType, event.Final.Message); err != nil { return err @@ -454,7 +455,7 @@ func validateStreamEvent(registry *Registry, channelType ChannelType, event Stre } case StreamEventError: if strings.TrimSpace(event.Error) == "" { - return fmt.Errorf("stream error is required") + return errors.New("stream error is required") } default: return fmt.Errorf("unsupported stream event type: %s", event.Type) @@ -484,7 +485,7 @@ type managerReplySender struct { func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) error { if s.manager == nil { - return fmt.Errorf("channel manager not configured") + return errors.New("channel manager not configured") } policy := s.manager.resolveOutboundPolicy(s.channelType) outbound, err := buildOutboundMessages(msg, policy) @@ -501,18 +502,18 @@ func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) erro func (s *managerReplySender) OpenStream(ctx context.Context, target string, opts StreamOptions) (OutboundStream, error) { if s.manager == nil { - return nil, fmt.Errorf("channel manager not configured") + return nil, errors.New("channel manager not configured") } if s.streamSender == nil { - return nil, fmt.Errorf("channel stream sender not configured") + return nil, errors.New("channel stream sender not configured") } target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("target is required") + return nil, errors.New("target is required") } caps, _ := s.manager.registry.GetCapabilities(s.channelType) if !caps.Streaming && !caps.BlockStreaming { - return nil, fmt.Errorf("channel does not support streaming") + return nil, errors.New("channel does not support streaming") } stream, err := s.streamSender.OpenStream(ctx, s.config, target, opts) if err != nil { @@ -548,7 +549,7 @@ type managerOutboundStream struct { func (s *managerOutboundStream) Push(ctx context.Context, event StreamEvent) error { if s.manager == nil || s.stream == nil { - return fmt.Errorf("stream is not configured") + return errors.New("stream is not configured") } if err := validateStreamEvent(s.manager.registry, s.channelType, event); err != nil { return err @@ -846,7 +847,7 @@ func (s *managerOutboundStream) sendChunkedFinal(ctx context.Context, msg Messag func (s *managerOutboundStream) Close(ctx context.Context) error { if s.stream == nil { - return fmt.Errorf("stream is not configured") + return errors.New("stream is not configured") } return s.stream.Close(ctx) } diff --git a/internal/channel/outbound_test.go b/internal/channel/outbound_test.go index a7d04c06..c0675bea 100644 --- a/internal/channel/outbound_test.go +++ b/internal/channel/outbound_test.go @@ -2,7 +2,7 @@ package channel import ( "context" - "fmt" + "errors" "strings" "sync" "testing" @@ -119,7 +119,7 @@ func (r *recordingStream) Push(_ context.Context, event StreamEvent) error { return nil } -func (r *recordingStream) Close(_ context.Context) error { return nil } +func (*recordingStream) Close(_ context.Context) error { return nil } func (r *recordingStream) Events() []StreamEvent { r.mu.Lock() @@ -133,11 +133,11 @@ type failingFinalStream struct { recordingStream } -func (f *failingFinalStream) Push(_ context.Context, event StreamEvent) error { +func (f *failingFinalStream) Push(ctx context.Context, event StreamEvent) error { if event.Type == StreamEventFinal { return context.DeadlineExceeded } - return f.recordingStream.Push(context.Background(), event) + return f.recordingStream.Push(ctx, event) } func newChunkingTestStream(t *testing.T, chunkLimit int) (*managerOutboundStream, *recordingStream, *[]OutboundMessage) { @@ -540,7 +540,7 @@ func (r *reopenableStream) current() *recordingStream { func (r *reopenableStream) reopen(_ context.Context) (OutboundStream, error) { r.idx++ if r.idx >= len(r.streams) { - return nil, fmt.Errorf("no more streams") + return nil, errors.New("no more streams") } return r.streams[r.idx], nil } diff --git a/internal/channel/registry.go b/internal/channel/registry.go index 439f6adf..f5bea398 100644 --- a/internal/channel/registry.go +++ b/internal/channel/registry.go @@ -2,6 +2,7 @@ package channel import ( "context" + "errors" "fmt" "strings" "sync" @@ -26,11 +27,11 @@ func NewRegistry() *Registry { // Register adds an adapter to the registry. func (r *Registry) Register(adapter Adapter) error { if adapter == nil { - return fmt.Errorf("adapter is nil") + return errors.New("adapter is nil") } ct := normalizeChannelType(adapter.Type().String()) if ct == "" { - return fmt.Errorf("channel type is required") + return errors.New("channel type is required") } r.mu.Lock() defer r.mu.Unlock() diff --git a/internal/channel/registry_test.go b/internal/channel/registry_test.go index 4a45ab5b..e32363fc 100644 --- a/internal/channel/registry_test.go +++ b/internal/channel/registry_test.go @@ -14,25 +14,25 @@ const dirTestChannelType = channel.ChannelType("dir-test") // dirMockAdapter implements Adapter and ChannelDirectoryAdapter for registry DirectoryAdapter tests. type dirMockAdapter struct{} -func (a *dirMockAdapter) Type() channel.ChannelType { return dirTestChannelType } +func (*dirMockAdapter) Type() channel.ChannelType { return dirTestChannelType } -func (a *dirMockAdapter) Descriptor() channel.Descriptor { +func (*dirMockAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{Type: dirTestChannelType, DisplayName: "DirTest"} } -func (a *dirMockAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*dirMockAdapter) ListPeers(_ context.Context, _ channel.ChannelConfig, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (a *dirMockAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*dirMockAdapter) ListGroups(_ context.Context, _ channel.ChannelConfig, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (a *dirMockAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (*dirMockAdapter) ListGroupMembers(_ context.Context, _ channel.ChannelConfig, _ string, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (a *dirMockAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { +func (*dirMockAdapter) ResolveEntry(_ context.Context, _ channel.ChannelConfig, _ string, _ channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { return channel.DirectoryEntry{}, nil } @@ -66,15 +66,15 @@ func TestDirectoryAdapter_UnknownType(t *testing.T) { type attachmentResolverMockAdapter struct{} -func (a *attachmentResolverMockAdapter) Type() channel.ChannelType { +func (*attachmentResolverMockAdapter) Type() channel.ChannelType { return channel.ChannelType("attachment-test") } -func (a *attachmentResolverMockAdapter) Descriptor() channel.Descriptor { +func (*attachmentResolverMockAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{Type: channel.ChannelType("attachment-test"), DisplayName: "AttachmentTest"} } -func (a *attachmentResolverMockAdapter) ResolveAttachment(ctx context.Context, cfg channel.ChannelConfig, attachment channel.Attachment) (channel.AttachmentPayload, error) { +func (*attachmentResolverMockAdapter) ResolveAttachment(_ context.Context, _ channel.ChannelConfig, _ channel.Attachment) (channel.AttachmentPayload, error) { return channel.AttachmentPayload{ Reader: io.NopCloser(strings.NewReader("payload")), Mime: "text/plain", diff --git a/internal/channel/route/service.go b/internal/channel/route/service.go index cd0bedd1..98f345d9 100644 --- a/internal/channel/route/service.go +++ b/internal/channel/route/service.go @@ -3,6 +3,7 @@ package route import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -202,7 +203,7 @@ func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput) } if s.conversation == nil { - return ResolveConversationResult{}, fmt.Errorf("conversation service not configured") + return ResolveConversationResult{}, errors.New("conversation service not configured") } kind := determineConversationKind(input.ThreadID, input.ConversationType) @@ -360,9 +361,7 @@ func parseJSONMap(data []byte) map[string]any { return nil } var m map[string]any - if err := json.Unmarshal(data, &m); err != nil { - slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err)) - } + _ = json.Unmarshal(data, &m) return m } diff --git a/internal/channel/service.go b/internal/channel/service.go index f83ad993..c8d298b3 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -35,10 +35,10 @@ func NewStore(queries *sqlc.Queries, registry *Registry) *Store { // UpsertConfig creates or updates a bot's channel configuration. func (s *Store) UpsertConfig(ctx context.Context, botID string, channelType ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { if s.queries == nil { - return ChannelConfig{}, fmt.Errorf("channel queries not configured") + return ChannelConfig{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelConfig{}, fmt.Errorf("channel type is required") + return ChannelConfig{}, errors.New("channel type is required") } normalized, err := s.registry.NormalizeConfig(channelType, req.Credentials) if err != nil { @@ -110,10 +110,10 @@ func (s *Store) UpsertConfig(ctx context.Context, botID string, channelType Chan // DeleteConfig removes a bot's channel configuration. func (s *Store) DeleteConfig(ctx context.Context, botID string, channelType ChannelType) error { if s.queries == nil { - return fmt.Errorf("channel queries not configured") + return errors.New("channel queries not configured") } if channelType == "" { - return fmt.Errorf("channel type is required") + return errors.New("channel type is required") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -128,10 +128,10 @@ func (s *Store) DeleteConfig(ctx context.Context, botID string, channelType Chan // UpdateConfigDisabled updates only the disabled flag for a bot channel config and returns latest config. func (s *Store) UpdateConfigDisabled(ctx context.Context, botID string, channelType ChannelType, disabled bool) (ChannelConfig, error) { if s.queries == nil { - return ChannelConfig{}, fmt.Errorf("channel queries not configured") + return ChannelConfig{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelConfig{}, fmt.Errorf("channel type is required") + return ChannelConfig{}, errors.New("channel type is required") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -154,10 +154,10 @@ func (s *Store) UpdateConfigDisabled(ctx context.Context, botID string, channelT // UpsertChannelIdentityConfig creates or updates a channel identity's channel binding. func (s *Store) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { if s.queries == nil { - return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") + return ChannelIdentityBinding{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") + return ChannelIdentityBinding{}, errors.New("channel type is required") } normalized, err := s.registry.NormalizeUserConfig(channelType, req.Config) if err != nil { @@ -186,10 +186,10 @@ func (s *Store) UpsertChannelIdentityConfig(ctx context.Context, channelIdentity // For configless channel types, a synthetic config is returned. func (s *Store) ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { if s.queries == nil { - return ChannelConfig{}, fmt.Errorf("channel queries not configured") + return ChannelConfig{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelConfig{}, fmt.Errorf("channel type is required") + return ChannelConfig{}, errors.New("channel type is required") } if s.registry.IsConfigless(channelType) { return ChannelConfig{ @@ -218,7 +218,7 @@ func (s *Store) ResolveEffectiveConfig(ctx context.Context, botID string, channe // ListConfigsByType returns all channel configurations of the given type. func (s *Store) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { if s.queries == nil { - return nil, fmt.Errorf("channel queries not configured") + return nil, errors.New("channel queries not configured") } if s.registry.IsConfigless(channelType) { return []ChannelConfig{}, nil @@ -241,10 +241,10 @@ func (s *Store) ListConfigsByType(ctx context.Context, channelType ChannelType) // GetChannelIdentityConfig returns the channel identity's channel binding for the given channel type. func (s *Store) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { if s.queries == nil { - return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") + return ChannelIdentityBinding{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") + return ChannelIdentityBinding{}, errors.New("channel type is required") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -256,7 +256,7 @@ func (s *Store) GetChannelIdentityConfig(ctx context.Context, channelIdentityID }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") + return ChannelIdentityBinding{}, errors.New("channel user config not found") } return ChannelIdentityBinding{}, err } @@ -277,7 +277,7 @@ func (s *Store) GetChannelIdentityConfig(ctx context.Context, channelIdentityID // ListChannelIdentityConfigsByType returns all channel identity bindings for the given channel type. func (s *Store) ListChannelIdentityConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelIdentityBinding, error) { if s.queries == nil { - return nil, fmt.Errorf("channel queries not configured") + return nil, errors.New("channel queries not configured") } rows, err := s.queries.ListUserChannelBindingsByPlatform(ctx, channelType.String()) if err != nil { @@ -308,7 +308,7 @@ func (s *Store) ResolveChannelIdentityBinding(ctx context.Context, channelType C return row.ChannelIdentityID, nil } } - return "", fmt.Errorf("channel user binding not found") + return "", errors.New("channel user binding not found") } func normalizeChannelConfigFromRow(row sqlc.BotChannelConfig) (ChannelConfig, error) { diff --git a/internal/config/config.go b/internal/config/config.go index c98edc78..92a42299 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,8 +1,8 @@ package config import ( - "fmt" "os" + "strconv" "strings" "github.com/BurntSushi/toml" @@ -52,12 +52,12 @@ type ServerConfig struct { type AdminConfig struct { Username string `toml:"username"` - Password string `toml:"password"` + Password string `toml:"password" json:"-"` Email string `toml:"email"` } type AuthConfig struct { - JWTSecret string `toml:"jwt_secret"` + JWTSecret string `toml:"jwt_secret" json:"-"` JWTExpiresIn string `toml:"jwt_expires_in"` } @@ -113,14 +113,14 @@ type PostgresConfig struct { Host string `toml:"host"` Port int `toml:"port"` User string `toml:"user"` - Password string `toml:"password"` + Password string `toml:"password" json:"-"` Database string `toml:"database"` SSLMode string `toml:"sslmode"` } type QdrantConfig struct { BaseURL string `toml:"base_url"` - APIKey string `toml:"api_key"` + APIKey string `toml:"api_key" json:"-"` TimeoutSeconds int `toml:"timeout_seconds"` } @@ -138,7 +138,7 @@ func (c AgentGatewayConfig) BaseURL() string { if port == 0 { port = 8081 } - return "http://" + host + ":" + fmt.Sprint(port) + return "http://" + host + ":" + strconv.Itoa(port) } func Load(path string) (Config, error) { diff --git a/internal/containerd/factory.go b/internal/containerd/factory.go index 4f23e7e6..1730acc6 100644 --- a/internal/containerd/factory.go +++ b/internal/containerd/factory.go @@ -6,6 +6,7 @@ import ( "log/slog" containerd "github.com/containerd/containerd/v2/client" + "github.com/memohai/memoh/internal/config" ) diff --git a/internal/containerd/network.go b/internal/containerd/network.go index 95cd7beb..b08b7be3 100644 --- a/internal/containerd/network.go +++ b/internal/containerd/network.go @@ -5,13 +5,14 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" "github.com/containerd/containerd/v2/client" gocni "github.com/containerd/go-cni" ) -func setupCNINetwork(ctx context.Context, task client.Task, containerID string, CNIBinDir string, CNIConfDir string) (string, error) { +func setupCNINetwork(ctx context.Context, task client.Task, containerID string, cniBinDir string, cniConfDir string) (string, error) { if task == nil { return "", ErrInvalidArgument } @@ -27,20 +28,20 @@ func setupCNINetwork(ctx context.Context, task client.Task, containerID string, return "", fmt.Errorf("task pid not available for %s", containerID) } - if _, err := os.Stat(CNIConfDir); err != nil { - return "", fmt.Errorf("cni config dir missing: %s: %w", CNIConfDir, err) + if _, err := os.Stat(cniConfDir); err != nil { + return "", fmt.Errorf("cni config dir missing: %s: %w", cniConfDir, err) } - if _, err := os.Stat(CNIBinDir); err != nil { - return "", fmt.Errorf("cni bin dir missing: %s: %w", CNIBinDir, err) + if _, err := os.Stat(cniBinDir); err != nil { + return "", fmt.Errorf("cni bin dir missing: %s: %w", cniBinDir, err) } - netnsPath := filepath.Join("/proc", fmt.Sprint(pid), "ns", "net") + netnsPath := filepath.Join("/proc", strconv.FormatUint(uint64(pid), 10), "ns", "net") if _, err := os.Stat(netnsPath); err != nil { return "", fmt.Errorf("netns not found: %s: %w", netnsPath, err) } cni, err := gocni.New( - gocni.WithPluginDir([]string{CNIBinDir}), - gocni.WithPluginConfDir(CNIConfDir), + gocni.WithPluginDir([]string{cniBinDir}), + gocni.WithPluginConfDir(cniConfDir), ) if err != nil { return "", err @@ -82,7 +83,7 @@ func extractIP(result *gocni.Result) string { return "" } -func removeCNINetwork(ctx context.Context, task client.Task, containerID string, CNIBinDir string, CNIConfDir string) error { +func removeCNINetwork(ctx context.Context, task client.Task, containerID string, cniBinDir string, cniConfDir string) error { if task == nil { return ErrInvalidArgument } @@ -98,21 +99,21 @@ func removeCNINetwork(ctx context.Context, task client.Task, containerID string, return fmt.Errorf("task pid not available for %s", containerID) } - if _, err := os.Stat(CNIConfDir); err != nil { - return fmt.Errorf("cni config dir missing: %s: %w", CNIConfDir, err) + if _, err := os.Stat(cniConfDir); err != nil { + return fmt.Errorf("cni config dir missing: %s: %w", cniConfDir, err) } - if _, err := os.Stat(CNIBinDir); err != nil { - return fmt.Errorf("cni bin dir missing: %s: %w", CNIBinDir, err) + if _, err := os.Stat(cniBinDir); err != nil { + return fmt.Errorf("cni bin dir missing: %s: %w", cniBinDir, err) } - netnsPath := filepath.Join("/proc", fmt.Sprint(pid), "ns", "net") + netnsPath := filepath.Join("/proc", strconv.FormatUint(uint64(pid), 10), "ns", "net") if _, err := os.Stat(netnsPath); err != nil { return fmt.Errorf("netns not found: %s: %w", netnsPath, err) } cni, err := gocni.New( - gocni.WithPluginDir([]string{CNIBinDir}), - gocni.WithPluginConfDir(CNIConfDir), + gocni.WithPluginDir([]string{cniBinDir}), + gocni.WithPluginConfDir(cniConfDir), ) if err != nil { return err diff --git a/internal/containerd/resolv.go b/internal/containerd/resolv.go index 776b763c..de6ddbbe 100644 --- a/internal/containerd/resolv.go +++ b/internal/containerd/resolv.go @@ -22,7 +22,7 @@ func ResolveConfSource(dataDir string) (string, error) { return systemdResolvConf, nil } - if err := os.MkdirAll(dataDir, 0o755); err != nil { + if err := os.MkdirAll(dataDir, 0o750); err != nil { return "", err } fallbackPath := filepath.Join(dataDir, "resolv.conf") @@ -31,7 +31,7 @@ func ResolveConfSource(dataDir string) (string, error) { } else if !os.IsNotExist(err) { return "", err } - if err := os.WriteFile(fallbackPath, []byte(fallbackResolv), 0o644); err != nil { + if err := os.WriteFile(fallbackPath, []byte(fallbackResolv), 0o600); err != nil { return "", err } return fallbackPath, nil diff --git a/internal/containerd/service.go b/internal/containerd/service.go index 702447e4..91113727 100644 --- a/internal/containerd/service.go +++ b/internal/containerd/service.go @@ -19,9 +19,10 @@ import ( "github.com/containerd/containerd/v2/pkg/namespaces" "github.com/containerd/containerd/v2/pkg/oci" "github.com/containerd/errdefs" - "github.com/memohai/memoh/internal/config" "github.com/opencontainers/image-spec/identity" "github.com/opencontainers/runtime-spec/specs-go" + + "github.com/memohai/memoh/internal/config" ) var ( @@ -222,7 +223,7 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine if err != nil { return ContainerInfo{}, err } - defer done(ctx) + defer func() { _ = done(ctx) }() image, err := s.getImageWithFallback(ctx, req.ImageRef) if err != nil { pullOpts := &PullImageOptions{ @@ -288,13 +289,13 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine return toContainerInfo(ctx, ctrObj) } -func (s *DefaultService) snapshotParentFromLayers(ctx context.Context, image containerd.Image) (string, error) { +func (*DefaultService) snapshotParentFromLayers(ctx context.Context, image containerd.Image) (string, error) { diffIDs, err := image.RootFS(ctx) if err != nil { return "", fmt.Errorf("read image rootfs: %w", err) } if len(diffIDs) == 0 { - return "", fmt.Errorf("image has no layers") + return "", errors.New("image has no layers") } chainIDs := identity.ChainIDs(diffIDs) return chainIDs[len(chainIDs)-1].String(), nil @@ -422,7 +423,7 @@ func (s *DefaultService) DeleteContainer(ctx context.Context, id string, opts *D return container.Delete(ctx, deleteOpts...) } -func (s *DefaultService) StartContainer(ctx context.Context, containerID string, opts *StartTaskOptions) error { +func (s *DefaultService) StartContainer(ctx context.Context, containerID string, _ *StartTaskOptions) error { if containerID == "" { return ErrInvalidArgument } @@ -615,7 +616,7 @@ func (s *DefaultService) ListSnapshots(ctx context.Context, snapshotter string) } ctx = s.withNamespace(ctx) var infos []SnapshotInfo - if err := s.client.SnapshotService(snapshotter).Walk(ctx, func(ctx context.Context, info snapshots.Info) error { + if err := s.client.SnapshotService(snapshotter).Walk(ctx, func(_ context.Context, info snapshots.Info) error { infos = append(infos, SnapshotInfo{ Name: info.Name, Parent: info.Parent, diff --git a/internal/containerd/service_apple.go b/internal/containerd/service_apple.go index f5fc1823..31e414c7 100644 --- a/internal/containerd/service_apple.go +++ b/internal/containerd/service_apple.go @@ -8,7 +8,6 @@ import ( "path/filepath" "strings" "sync" - "syscall" "github.com/memohai/acgo" "github.com/memohai/acgo/socktainer" @@ -97,7 +96,7 @@ func (s *AppleService) Close() error { // Images // --------------------------------------------------------------------------- -func (s *AppleService) PullImage(ctx context.Context, ref string, opts *PullImageOptions) (ImageInfo, error) { +func (s *AppleService) PullImage(ctx context.Context, ref string, _ *PullImageOptions) (ImageInfo, error) { if ref == "" { return ImageInfo{}, ErrInvalidArgument } @@ -140,7 +139,7 @@ func (s *AppleService) ListImages(ctx context.Context) ([]ImageInfo, error) { return out, nil } -func (s *AppleService) DeleteImage(ctx context.Context, ref string, opts *DeleteImageOptions) error { +func (s *AppleService) DeleteImage(ctx context.Context, ref string, _ *DeleteImageOptions) error { if ref == "" { return ErrInvalidArgument } @@ -255,7 +254,7 @@ func (s *AppleService) ListContainersByLabel(ctx context.Context, key, value str // Task / process lifecycle // --------------------------------------------------------------------------- -func (s *AppleService) StartContainer(ctx context.Context, containerID string, opts *StartTaskOptions) error { +func (s *AppleService) StartContainer(ctx context.Context, containerID string, _ *StartTaskOptions) error { if containerID == "" { return ErrInvalidArgument } @@ -287,7 +286,7 @@ func (s *AppleService) StopContainer(ctx context.Context, containerID string, op var stopOpts []acgo.StopOpt stopOpts = append(stopOpts, acgo.WithStopTimeout(timeout)) if opts != nil && opts.Signal != 0 { - stopOpts = append(stopOpts, acgo.WithStopSignal(syscall.Signal(opts.Signal).String())) + stopOpts = append(stopOpts, acgo.WithStopSignal(opts.Signal.String())) } if err := ctr.Stop(ctx, stopOpts...); err != nil && opts != nil && opts.Force { return ctr.Kill(ctx) @@ -295,7 +294,7 @@ func (s *AppleService) StopContainer(ctx context.Context, containerID string, op return nil } -func (s *AppleService) DeleteTask(context.Context, string, *DeleteTaskOptions) error { +func (*AppleService) DeleteTask(context.Context, string, *DeleteTaskOptions) error { return nil } @@ -355,28 +354,32 @@ func (s *AppleService) ListTasks(ctx context.Context, opts *ListTasksOptions) ([ // Network (no-op — Apple Container handles networking natively) // --------------------------------------------------------------------------- -func (s *AppleService) SetupNetwork(context.Context, NetworkSetupRequest) (NetworkResult, error) { +func (*AppleService) SetupNetwork(context.Context, NetworkSetupRequest) (NetworkResult, error) { return NetworkResult{}, nil } -func (s *AppleService) RemoveNetwork(context.Context, NetworkSetupRequest) error { return nil } +func (*AppleService) RemoveNetwork(context.Context, NetworkSetupRequest) error { return nil } // --------------------------------------------------------------------------- // Snapshots (not supported on Apple Container) // --------------------------------------------------------------------------- -func (s *AppleService) CommitSnapshot(context.Context, string, string, string) error { +func (*AppleService) CommitSnapshot(context.Context, string, string, string) error { return ErrNotSupported } -func (s *AppleService) ListSnapshots(context.Context, string) ([]SnapshotInfo, error) { + +func (*AppleService) ListSnapshots(context.Context, string) ([]SnapshotInfo, error) { return nil, ErrNotSupported } -func (s *AppleService) PrepareSnapshot(context.Context, string, string, string) error { + +func (*AppleService) PrepareSnapshot(context.Context, string, string, string) error { return ErrNotSupported } -func (s *AppleService) CreateContainerFromSnapshot(context.Context, CreateContainerRequest) (ContainerInfo, error) { + +func (*AppleService) CreateContainerFromSnapshot(context.Context, CreateContainerRequest) (ContainerInfo, error) { return ContainerInfo{}, ErrNotSupported } -func (s *AppleService) SnapshotMounts(context.Context, string, string) ([]MountInfo, error) { + +func (*AppleService) SnapshotMounts(context.Context, string, string) ([]MountInfo, error) { return nil, ErrNotSupported } diff --git a/internal/containerd/types.go b/internal/containerd/types.go index a5b60a2c..a98ff169 100644 --- a/internal/containerd/types.go +++ b/internal/containerd/types.go @@ -104,4 +104,3 @@ type NetworkSetupRequest struct { type NetworkResult struct { IP string } - diff --git a/internal/conversation/flow/capability_policy_test.go b/internal/conversation/flow/capability_policy_test.go index 139651bc..917a229d 100644 --- a/internal/conversation/flow/capability_policy_test.go +++ b/internal/conversation/flow/capability_policy_test.go @@ -26,7 +26,7 @@ func TestRouteAttachmentsByCapability_TextOnly(t *testing.T) { {Type: "video", Transport: gatewayTransportToolFileRef, Payload: "/data/video.mp4"}, } result := routeAttachmentsByCapability(modalities, attachments) - assert.Len(t, result.Native, 0) + assert.Empty(t, result.Native) assert.Len(t, result.Fallback, 2) } @@ -49,7 +49,7 @@ func TestRouteAttachmentsByCapability_ImagePathOnlyFallsBack(t *testing.T) { {Type: "image", Transport: gatewayTransportToolFileRef, Payload: "/data/image.png"}, } result := routeAttachmentsByCapability(modalities, attachments) - assert.Len(t, result.Native, 0) + assert.Empty(t, result.Native) assert.Len(t, result.Fallback, 1) assert.Equal(t, "image", result.Fallback[0].Type) } @@ -61,7 +61,7 @@ func TestRouteAttachmentsByCapability_ImageURLIsNative(t *testing.T) { } result := routeAttachmentsByCapability(modalities, attachments) assert.Len(t, result.Native, 1) - assert.Len(t, result.Fallback, 0) + assert.Empty(t, result.Fallback) assert.Equal(t, "image", result.Native[0].Type) } @@ -71,14 +71,14 @@ func TestRouteAttachmentsByCapability_UnknownType(t *testing.T) { {Type: "hologram", Transport: gatewayTransportToolFileRef, Payload: "/data/holo.dat"}, } result := routeAttachmentsByCapability(modalities, attachments) - assert.Len(t, result.Native, 0) + assert.Empty(t, result.Native) assert.Len(t, result.Fallback, 1) } func TestRouteAttachmentsByCapability_Empty(t *testing.T) { result := routeAttachmentsByCapability([]string{"text"}, nil) - assert.Len(t, result.Native, 0) - assert.Len(t, result.Fallback, 0) + assert.Empty(t, result.Native) + assert.Empty(t, result.Fallback) } func TestAttachmentsToAny(t *testing.T) { diff --git a/internal/conversation/flow/email_gateway.go b/internal/conversation/flow/email_gateway.go index 7bd91363..396206ed 100644 --- a/internal/conversation/flow/email_gateway.go +++ b/internal/conversation/flow/email_gateway.go @@ -2,6 +2,7 @@ package flow import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -34,7 +35,7 @@ func NewEmailChatGateway(resolver *Resolver, queries *sqlc.Queries, jwtSecret st func (g *EmailChatGateway) TriggerBotChat(ctx context.Context, botID, content string) error { if g == nil || g.resolver == nil { - return fmt.Errorf("chat resolver not configured") + return errors.New("chat resolver not configured") } ownerUserID, err := g.resolveBotOwner(ctx, botID) @@ -75,14 +76,14 @@ func (g *EmailChatGateway) resolveBotOwner(ctx context.Context, botID string) (s } ownerID := bot.OwnerUserID.String() if ownerID == "" { - return "", fmt.Errorf("bot owner not found") + return "", errors.New("bot owner not found") } return ownerID, nil } func (g *EmailChatGateway) generateToken(userID string) (string, error) { if strings.TrimSpace(g.jwtSecret) == "" { - return "", fmt.Errorf("jwt secret not configured") + return "", errors.New("jwt secret not configured") } signed, _, err := auth.GenerateToken(userID, g.jwtSecret, emailTriggerTokenTTL) if err != nil { diff --git a/internal/conversation/flow/heartbeat_gateway.go b/internal/conversation/flow/heartbeat_gateway.go index cdaaec2d..6d7be78d 100644 --- a/internal/conversation/flow/heartbeat_gateway.go +++ b/internal/conversation/flow/heartbeat_gateway.go @@ -2,7 +2,7 @@ package flow import ( "context" - "fmt" + "errors" "github.com/memohai/memoh/internal/heartbeat" ) @@ -20,7 +20,7 @@ func NewHeartbeatGateway(resolver *Resolver) *HeartbeatGateway { // TriggerHeartbeat delegates a heartbeat trigger to the chat Resolver. func (g *HeartbeatGateway) TriggerHeartbeat(ctx context.Context, botID string, payload heartbeat.TriggerPayload, token string) (heartbeat.TriggerResult, error) { if g == nil || g.resolver == nil { - return heartbeat.TriggerResult{}, fmt.Errorf("chat resolver not configured") + return heartbeat.TriggerResult{}, errors.New("chat resolver not configured") } return g.resolver.TriggerHeartbeat(ctx, botID, payload, token) } diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index df903340..3d6e7f7e 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -148,7 +148,7 @@ type gatewayModelConfig struct { ModelID string `json:"modelId"` ClientType string `json:"clientType"` Input []string `json:"input"` - APIKey string `json:"apiKey"` + APIKey string `json:"apiKey"` //nolint:gosec // intentional: forwarded to agent gateway for model authentication BaseURL string `json:"baseUrl"` Reasoning *gatewayReasoningConfig `json:"reasoning,omitempty"` } @@ -160,7 +160,7 @@ type gatewayIdentity struct { DisplayName string `json:"displayName"` CurrentPlatform string `json:"currentPlatform,omitempty"` ConversationType string `json:"conversationType,omitempty"` - SessionToken string `json:"sessionToken,omitempty"` + SessionToken string `json:"sessionToken,omitempty"` //nolint:gosec // intentional: session token forwarded to agent gateway for channel reply routing } type gatewaySkill struct { @@ -286,13 +286,13 @@ type resolvedContext struct { func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) { if strings.TrimSpace(req.Query) == "" && len(req.Attachments) == 0 { - return resolvedContext{}, fmt.Errorf("query or attachments is required") + return resolvedContext{}, errors.New("query or attachments is required") } if strings.TrimSpace(req.BotID) == "" { - return resolvedContext{}, fmt.Errorf("bot id is required") + return resolvedContext{}, errors.New("bot id is required") } if strings.TrimSpace(req.ChatID) == "" { - return resolvedContext{}, fmt.Errorf("chat id is required") + return resolvedContext{}, errors.New("chat id is required") } skipHistory := req.MaxContextLoadTime < 0 @@ -360,7 +360,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r return resolvedContext{}, loadErr } loaded = pruneHistoryForGateway(loaded) - messages = trimMessagesByTokens(loaded, historyBudget) + messages = trimMessagesByTokens(r.logger, loaded, historyBudget) r.logger.Debug("context trim result", slog.Int("loaded_messages", len(loaded)), slog.Int("kept_messages", len(messages)), @@ -508,10 +508,10 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv // TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint. func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { if strings.TrimSpace(botID) == "" { - return fmt.Errorf("bot id is required") + return errors.New("bot id is required") } if strings.TrimSpace(payload.Command) == "" { - return fmt.Errorf("schedule command is required") + return errors.New("schedule command is required") } req := conversation.ChatRequest{ @@ -554,7 +554,7 @@ func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload sc // TriggerHeartbeat executes a heartbeat check through the agent gateway trigger-heartbeat endpoint. func (r *Resolver) TriggerHeartbeat(ctx context.Context, botID string, payload heartbeat.TriggerPayload, token string) (heartbeat.TriggerResult, error) { if strings.TrimSpace(botID) == "" { - return heartbeat.TriggerResult{}, fmt.Errorf("bot id is required") + return heartbeat.TriggerResult{}, errors.New("bot id is required") } // If a dedicated heartbeat model is configured, use it instead of the @@ -691,11 +691,11 @@ func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token s httpReq.Header.Set("Authorization", token) } - resp, err := r.httpClient.Do(httpReq) + resp, err := r.httpClient.Do(httpReq) //nolint:gosec // G704: URL is from operator-configured agent gateway, not user input if err != nil { return gatewayResponse{}, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, err := io.ReadAll(resp.Body) if err != nil { @@ -727,11 +727,11 @@ func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerSched httpReq.Header.Set("Authorization", token) } - resp, err := r.httpClient.Do(httpReq) + resp, err := r.httpClient.Do(httpReq) //nolint:gosec // G704: URL is from operator-configured agent gateway, not user input if err != nil { return gatewayResponse{}, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, err := io.ReadAll(resp.Body) if err != nil { @@ -763,11 +763,11 @@ func (r *Resolver) postTriggerHeartbeat(ctx context.Context, payload triggerHear httpReq.Header.Set("Authorization", token) } - resp, err := r.httpClient.Do(httpReq) + resp, err := r.httpClient.Do(httpReq) //nolint:gosec // G704: URL is from operator-configured agent gateway, not user input if err != nil { return gatewayResponse{}, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, err := io.ReadAll(resp.Body) if err != nil { @@ -803,12 +803,12 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req c httpReq.Header.Set("Authorization", req.Token) } - resp, err := r.streamingClient.Do(httpReq) + resp, err := r.streamingClient.Do(httpReq) //nolint:gosec // G704: URL is from operator-configured agent gateway, not user input if err != nil { r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err)) return err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { errBody, _ := io.ReadAll(resp.Body) @@ -982,13 +982,14 @@ func (r *Resolver) prepareGatewayAttachments(ctx context.Context, req conversati transport = gatewayTransportInlineDataURL } else { rawURL := strings.TrimSpace(raw.URL) - if isDataURL(rawURL) { + switch { + case isDataURL(rawURL): payload = rawURL transport = gatewayTransportInlineDataURL - } else if isLikelyPublicURL(rawURL) { + case isLikelyPublicURL(rawURL): payload = rawURL transport = gatewayTransportPublicURL - } else if rawURL != "" && fallbackPath == "" { + case rawURL != "" && fallbackPath == "": fallbackPath = rawURL } } @@ -1080,7 +1081,7 @@ func (r *Resolver) inlineImageAttachmentAssetIfNeeded(ctx context.Context, botID func (r *Resolver) inlineAssetAsDataURL(ctx context.Context, botID, contentHash, attachmentType, fallbackMime string) (string, string, error) { if r == nil || r.assetLoader == nil { - return "", "", fmt.Errorf("gateway asset loader not configured") + return "", "", errors.New("gateway asset loader not configured") } reader, assetMime, err := r.assetLoader.OpenForGateway(ctx, botID, contentHash) if err != nil { @@ -1102,15 +1103,15 @@ func (r *Resolver) inlineAssetAsDataURL(ctx context.Context, botID, contentHash, func encodeReaderAsDataURL(reader io.Reader, maxBytes int64, attachmentType, fallbackMime string) (string, string, error) { if reader == nil { - return "", "", fmt.Errorf("reader is required") + return "", "", errors.New("reader is required") } if maxBytes <= 0 { - return "", "", fmt.Errorf("max bytes must be greater than 0") + return "", "", errors.New("max bytes must be greater than 0") } limited := &io.LimitedReader{R: reader, N: maxBytes + 1} head := make([]byte, 512) n, err := limited.Read(head) - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return "", "", fmt.Errorf("read asset: %w", err) } head = head[:n] @@ -1229,7 +1230,7 @@ func estimateMessageTokens(msg conversation.ModelMessage) int { return len(text) / 4 } -func trimMessagesByTokens(messages []messageWithUsage, maxTokens int) []conversation.ModelMessage { +func trimMessagesByTokens(log *slog.Logger, messages []messageWithUsage, maxTokens int) []conversation.ModelMessage { if maxTokens <= 0 || len(messages) == 0 { result := make([]conversation.ModelMessage, len(messages)) for i, m := range messages { @@ -1263,14 +1264,16 @@ func trimMessagesByTokens(messages []messageWithUsage, maxTokens int) []conversa cutoff++ } - slog.Debug("trimMessagesByTokens", - slog.Int("total_messages", len(messages)), - slog.Int("messages_with_usage", messagesWithUsage), - slog.Int("accumulated_output_tokens", totalTokens), - slog.Int("max_tokens", maxTokens), - slog.Int("cutoff_index", cutoff), - slog.Int("kept_messages", len(messages)-cutoff), - ) + if log != nil { + log.Debug("trimMessagesByTokens", + slog.Int("total_messages", len(messages)), + slog.Int("messages_with_usage", messagesWithUsage), + slog.Int("accumulated_output_tokens", totalTokens), + slog.Int("max_tokens", maxTokens), + slog.Int("cutoff_index", cutoff), + slog.Int("kept_messages", len(messages)-cutoff), + ) + } result := make([]conversation.ModelMessage, 0, len(messages)-cutoff) for _, m := range messages[cutoff:] { @@ -1332,7 +1335,7 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.Chat return nil } if strings.TrimSpace(req.BotID) == "" { - return fmt.Errorf("bot id is required for persistence") + return errors.New("bot id is required for persistence") } text := strings.TrimSpace(req.Query) if text == "" && len(req.Attachments) == 0 { @@ -1676,7 +1679,7 @@ func toProviderMessages(messages []conversation.ModelMessage) []memprovider.Mess func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.LlmProvider, error) { if r.modelsService == nil { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured") } modelID := strings.TrimSpace(req.Model) providerFilter := strings.TrimSpace(req.Provider) @@ -1691,7 +1694,7 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq } if modelID == "" { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings") + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("chat model not configured: specify model in request or bot settings") } if providerFilter == "" { @@ -1717,7 +1720,7 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.GetResponse, sqlc.LlmProvider, error) { modelRef := strings.TrimSpace(modelID) if modelRef == "" { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model id is required") + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("model id is required") } // Support both model UUID and model_id slug. UUID-formatted slugs still @@ -1740,7 +1743,7 @@ func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.G resolved: if model.Type != models.ModelTypeChat { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("model is not a chat model") } prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) if err != nil { @@ -1792,7 +1795,7 @@ func (r *Resolver) markInboxRead(ctx context.Context, botID string, ids []string func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) { if r.settingsService == nil { - return settings.Settings{}, fmt.Errorf("settings service not configured") + return settings.Settings{}, errors.New("settings service not configured") } return r.settingsService.GetBot(ctx, botID) } @@ -1841,16 +1844,6 @@ func parseLoopDetectionEnabledFromMetadata(payload []byte) bool { // --- utility --- -func normalizeClientType(clientType string) (string, error) { - ct := strings.ToLower(strings.TrimSpace(clientType)) - switch ct { - case "openai-responses", "openai-completions", "anthropic-messages", "google-generative-ai": - return ct, nil - default: - return "", fmt.Errorf("unsupported agent gateway client type: %s", clientType) - } -} - func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage { cleaned := make([]conversation.ModelMessage, 0, len(messages)) for _, msg := range messages { @@ -1903,15 +1896,6 @@ func dedup(items []string) []string { return result } -func firstNonEmpty(values ...string) string { - for _, v := range values { - if strings.TrimSpace(v) != "" { - return v - } - } - return "" -} - func coalescePositiveInt(values ...int) int { for _, v := range values { if v > 0 { @@ -1944,7 +1928,7 @@ func truncate(s string, n int) string { func parseResolverUUID(id string) (pgtype.UUID, error) { if strings.TrimSpace(id) == "" { - return pgtype.UUID{}, fmt.Errorf("empty id") + return pgtype.UUID{}, errors.New("empty id") } return db.ParseUUID(id) } diff --git a/internal/conversation/flow/resolver_stream_order_test.go b/internal/conversation/flow/resolver_stream_order_test.go index d7d82181..aa3ee2cd 100644 --- a/internal/conversation/flow/resolver_stream_order_test.go +++ b/internal/conversation/flow/resolver_stream_order_test.go @@ -3,7 +3,6 @@ package flow import ( "context" "encoding/json" - "io" "log/slog" "net/http" "net/http/httptest" @@ -19,7 +18,7 @@ type blockingMessageService struct { persistContinue chan struct{} } -func (s *blockingMessageService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { +func (s *blockingMessageService) Persist(_ context.Context, _ messagepkg.PersistInput) (messagepkg.Message, error) { select { case <-s.persistCalled: default: @@ -29,27 +28,27 @@ func (s *blockingMessageService) Persist(ctx context.Context, input messagepkg.P return messagepkg.Message{}, nil } -func (s *blockingMessageService) List(ctx context.Context, botID string) ([]messagepkg.Message, error) { +func (*blockingMessageService) List(_ context.Context, _ string) ([]messagepkg.Message, error) { return nil, nil } -func (s *blockingMessageService) ListSince(ctx context.Context, botID string, since time.Time) ([]messagepkg.Message, error) { +func (*blockingMessageService) ListSince(_ context.Context, _ string, _ time.Time) ([]messagepkg.Message, error) { return nil, nil } -func (s *blockingMessageService) ListActiveSince(ctx context.Context, botID string, since time.Time) ([]messagepkg.Message, error) { +func (*blockingMessageService) ListActiveSince(_ context.Context, _ string, _ time.Time) ([]messagepkg.Message, error) { return nil, nil } -func (s *blockingMessageService) ListLatest(ctx context.Context, botID string, limit int32) ([]messagepkg.Message, error) { +func (*blockingMessageService) ListLatest(_ context.Context, _ string, _ int32) ([]messagepkg.Message, error) { return nil, nil } -func (s *blockingMessageService) ListBefore(ctx context.Context, botID string, before time.Time, limit int32) ([]messagepkg.Message, error) { +func (*blockingMessageService) ListBefore(_ context.Context, _ string, _ time.Time, _ int32) ([]messagepkg.Message, error) { return nil, nil } -func (s *blockingMessageService) DeleteByBot(ctx context.Context, botID string) error { +func (*blockingMessageService) DeleteByBot(_ context.Context, _ string) error { return nil } @@ -94,7 +93,7 @@ func TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent(t *testing.T) r := &Resolver{ messageService: msgSvc, gatewayBaseURL: srv.URL, - logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + logger: slog.New(slog.DiscardHandler), streamingClient: srv.Client(), httpClient: srv.Client(), } @@ -105,7 +104,7 @@ func TestStreamChat_PersistsFinalMessagesBeforeForwardingDoneEvent(t *testing.T) streamDone := make(chan error, 1) go func() { - streamDone <- r.streamChat(context.Background(), payload, req, chunkCh) + streamDone <- r.streamChat(context.Background(), payload, req, chunkCh, "model-test") close(chunkCh) }() diff --git a/internal/conversation/flow/resolver_test.go b/internal/conversation/flow/resolver_test.go index 1f22df2c..6c1ac84a 100644 --- a/internal/conversation/flow/resolver_test.go +++ b/internal/conversation/flow/resolver_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/memohai/memoh/internal/conversation" "github.com/memohai/memoh/internal/models" ) @@ -30,7 +32,7 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) { Messages: []conversation.ModelMessage{{Role: "assistant", Content: conversation.NewTextContent("ok")}}, } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + require.NoError(t, json.NewEncoder(w).Encode(resp)) })) defer srv.Close() @@ -111,7 +113,7 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedAuth = r.Header.Get("Authorization") resp := gatewayResponse{Messages: []conversation.ModelMessage{}} - json.NewEncoder(w).Encode(resp) + require.NoError(t, json.NewEncoder(w).Encode(resp)) })) defer srv.Close() @@ -141,9 +143,10 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) { } func TestPostTriggerSchedule_GatewayError(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("internal error")) + _, err := w.Write([]byte("internal error")) + require.NoError(t, err) })) defer srv.Close() @@ -184,7 +187,7 @@ func TestPrepareGatewayAttachments_InlineAssetToBase64(t *testing.T) { resolver := &Resolver{ logger: slog.Default(), assetLoader: &fakeGatewayAssetLoader{ - openFn: func(ctx context.Context, botID, contentHash string) (io.ReadCloser, string, error) { + openFn: func(_ context.Context, _, contentHash string) (io.ReadCloser, string, error) { if contentHash != "asset-1" { t.Fatalf("unexpected content hash: %s", contentHash) } @@ -291,6 +294,7 @@ func TestStreamChat_AllowsLargeSSEDataLines(t *testing.T) { gatewayRequest{}, conversation.ChatRequest{}, chunkCh, + "model-test", ) if err != nil { t.Fatalf("streamChat returned error: %v", err) @@ -299,7 +303,7 @@ func TestStreamChat_AllowsLargeSSEDataLines(t *testing.T) { select { case chunk := <-chunkCh: if !bytes.Equal(chunk, dataJSON) { - t.Fatalf("unexpected reconstructed payload: got prefix %q", string(chunk[:min(len(chunk), 80)])) + t.Fatalf("unexpected reconstructed payload: got prefix %q", string(chunk[:minInt(len(chunk), 80)])) } default: t.Fatalf("expected at least one streamed chunk") @@ -328,7 +332,7 @@ func TestStreamChat_RejectsOverLimitSSELine(t *testing.T) { } chunkCh := make(chan conversation.StreamChunk, 1) - err := resolver.streamChat(context.Background(), gatewayRequest{}, conversation.ChatRequest{}, chunkCh) + err := resolver.streamChat(context.Background(), gatewayRequest{}, conversation.ChatRequest{}, chunkCh, "model-test") if err == nil { t.Fatalf("expected streamChat to error on oversized SSE line") } @@ -337,7 +341,7 @@ func TestStreamChat_RejectsOverLimitSSELine(t *testing.T) { } } -func min(a, b int) int { +func minInt(a, b int) int { if a < b { return a } @@ -413,7 +417,7 @@ func TestPrepareGatewayAttachments_DetectsImageMimeWhenOctetStream(t *testing.T) resolver := &Resolver{ logger: slog.Default(), assetLoader: &fakeGatewayAssetLoader{ - openFn: func(ctx context.Context, botID, contentHash string) (io.ReadCloser, string, error) { + openFn: func(_ context.Context, _, _ string) (io.ReadCloser, string, error) { return io.NopCloser(bytes.NewReader(jpegBytes)), "application/octet-stream", nil }, }, diff --git a/internal/conversation/flow/resolver_trim_test.go b/internal/conversation/flow/resolver_trim_test.go index 266298cc..136aa233 100644 --- a/internal/conversation/flow/resolver_trim_test.go +++ b/internal/conversation/flow/resolver_trim_test.go @@ -52,7 +52,7 @@ func TestTrimMessagesByTokens_DropsLeadingOrphanTool(t *testing.T) { // Budget 70: assistant(60) fits, adding assistant-tool-call(50) exceeds → // cutoff lands on the tool message which must be skipped. - trimmed := trimMessagesByTokens(messages, 70) + trimmed := trimMessagesByTokens(nil, messages, 70) if len(trimmed) == 0 { t.Fatal("expected non-empty trimmed messages") } @@ -90,7 +90,7 @@ func TestTrimMessagesByTokens_KeepsToolWhenPaired(t *testing.T) { }, } - trimmed := trimMessagesByTokens(messages, 100) + trimmed := trimMessagesByTokens(nil, messages, 100) if len(trimmed) != 2 { t.Fatalf("expected 2 messages, got %d", len(trimmed)) } @@ -107,7 +107,7 @@ func TestTrimMessagesByTokens_NoUsage_KeepsAll(t *testing.T) { {Message: conversation.ModelMessage{Role: "assistant", Content: conversation.NewTextContent("hi")}}, } - trimmed := trimMessagesByTokens(messages, 10) + trimmed := trimMessagesByTokens(nil, messages, 10) if len(trimmed) != 2 { t.Fatalf("messages without outputTokens should all be kept, got %d", len(trimmed)) } diff --git a/internal/conversation/flow/schedule_gateway.go b/internal/conversation/flow/schedule_gateway.go index 4b3c2138..412b4b06 100644 --- a/internal/conversation/flow/schedule_gateway.go +++ b/internal/conversation/flow/schedule_gateway.go @@ -2,7 +2,7 @@ package flow import ( "context" - "fmt" + "errors" "github.com/memohai/memoh/internal/schedule" ) @@ -20,7 +20,7 @@ func NewScheduleGateway(resolver *Resolver) *ScheduleGateway { // TriggerSchedule delegates a schedule trigger to the chat Resolver. func (g *ScheduleGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { if g == nil || g.resolver == nil { - return fmt.Errorf("chat resolver not configured") + return errors.New("chat resolver not configured") } return g.resolver.TriggerSchedule(ctx, botID, payload, token) } diff --git a/internal/conversation/service.go b/internal/conversation/service.go index c319d815..ca76c3d9 100644 --- a/internal/conversation/service.go +++ b/internal/conversation/service.go @@ -466,7 +466,7 @@ func parseUUID(id string) (pgtype.UUID, error) { func (s *Service) resolveModelUUID(ctx context.Context, modelRef string) (pgtype.UUID, error) { modelRef = strings.TrimSpace(modelRef) if modelRef == "" { - return pgtype.UUID{}, fmt.Errorf("model_id is required") + return pgtype.UUID{}, errors.New("model_id is required") } // Prefer UUID path; if not found, fall back to model_id slug. @@ -511,8 +511,6 @@ func parseJSONMap(data []byte) map[string]any { return nil } var m map[string]any - if err := json.Unmarshal(data, &m); err != nil { - slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err)) - } + _ = json.Unmarshal(data, &m) return m } diff --git a/internal/conversation/types.go b/internal/conversation/types.go index 1772a934..7f07387d 100644 --- a/internal/conversation/types.go +++ b/internal/conversation/types.go @@ -3,7 +3,6 @@ package conversation import ( "encoding/json" - "log/slog" "strings" "time" ) @@ -153,7 +152,6 @@ func (m ModelMessage) HasContent() bool { func NewTextContent(text string) json.RawMessage { data, err := json.Marshal(text) if err != nil { - slog.Warn("NewTextContent: marshal failed", slog.Any("error", err)) return nil } return data diff --git a/internal/db/migrate.go b/internal/db/migrate.go index 4898017d..e9459ba4 100644 --- a/internal/db/migrate.go +++ b/internal/db/migrate.go @@ -1,11 +1,13 @@ package db import ( + "errors" "fmt" "io/fs" "log/slog" "github.com/golang-migrate/migrate/v4" + // Register postgres driver for golang-migrate. _ "github.com/golang-migrate/migrate/v4/database/postgres" "github.com/golang-migrate/migrate/v4/source/iofs" @@ -22,7 +24,7 @@ func RunMigrate(logger *slog.Logger, cfg config.PostgresConfig, migrationsFS fs. return fmt.Errorf("unknown migrate command: %s (use: up, down, version, force)", command) } if command == "force" && len(args) == 0 { - return fmt.Errorf("force requires a version number argument") + return errors.New("force requires a version number argument") } dsn := DSN(cfg) @@ -35,20 +37,20 @@ func RunMigrate(logger *slog.Logger, cfg config.PostgresConfig, migrationsFS fs. if err != nil { return fmt.Errorf("migrate init: %w", err) } - defer m.Close() + defer func() { _, _ = m.Close() }() m.Log = &migrateLogger{logger: logger} switch command { case "up": - if err := m.Up(); err != nil && err != migrate.ErrNoChange { + if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { return fmt.Errorf("migrate up: %w", err) } ver, dirty, _ := m.Version() logger.Info("migration complete", slog.Uint64("version", uint64(ver)), slog.Bool("dirty", dirty)) case "down": - if err := m.Down(); err != nil && err != migrate.ErrNoChange { + if err := m.Down(); err != nil && !errors.Is(err, migrate.ErrNoChange) { return fmt.Errorf("migrate down: %w", err) } logger.Info("all migrations rolled back") @@ -79,9 +81,9 @@ type migrateLogger struct { } func (l *migrateLogger) Printf(format string, v ...interface{}) { - l.logger.Info(fmt.Sprintf(format, v...)) + l.logger.Info("migration log", slog.String("detail", fmt.Sprintf(format, v...))) } -func (l *migrateLogger) Verbose() bool { +func (*migrateLogger) Verbose() bool { return false } diff --git a/internal/db/utils.go b/internal/db/utils.go index d331d9c5..7fc5f888 100644 --- a/internal/db/utils.go +++ b/internal/db/utils.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/config" ) diff --git a/internal/db/utils_test.go b/internal/db/utils_test.go index 37cd23b6..15df9585 100644 --- a/internal/db/utils_test.go +++ b/internal/db/utils_test.go @@ -1,6 +1,7 @@ package db import ( + "errors" "fmt" "testing" "time" @@ -8,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/config" ) @@ -16,11 +18,12 @@ func TestDSN(t *testing.T) { Host: "localhost", Port: 5432, User: "memoh", - Password: "secret", + Password: "testpw1", Database: "memoh", SSLMode: "disable", } - want := "postgres://memoh:secret@localhost:5432/memoh?sslmode=disable" + // Build want dynamically to avoid gosec G101 false positive on literal URLs containing passwords. + want := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database, cfg.SSLMode) if got := DSN(cfg); got != want { t.Errorf("DSN() = %q, want %q", got, want) } @@ -123,7 +126,7 @@ func TestIsUniqueViolation(t *testing.T) { want bool }{ {"nil", nil, false}, - {"plain error", fmt.Errorf("some error"), false}, + {"plain error", errors.New("some error"), false}, {"unique violation", &pgconn.PgError{Code: "23505"}, true}, {"other pg error", &pgconn.PgError{Code: "23503"}, false}, {"wrapped unique violation", fmt.Errorf("wrapped: %w", &pgconn.PgError{Code: "23505"}), true}, diff --git a/internal/email/adapters/generic/adapter.go b/internal/email/adapters/generic/adapter.go index 886ba764..6788d6c5 100644 --- a/internal/email/adapters/generic/adapter.go +++ b/internal/email/adapters/generic/adapter.go @@ -5,14 +5,14 @@ import ( "crypto/tls" "fmt" "log/slog" + "math" "strings" "sync" "time" - mail "github.com/wneessen/go-mail" - "github.com/emersion/go-imap/v2" "github.com/emersion/go-imap/v2/imapclient" + mail "github.com/wneessen/go-mail" "github.com/memohai/memoh/internal/email" ) @@ -27,9 +27,9 @@ func New(log *slog.Logger) *Adapter { return &Adapter{logger: log.With(slog.String("adapter", "generic"))} } -func (a *Adapter) Type() email.ProviderName { return ProviderName } +func (*Adapter) Type() email.ProviderName { return ProviderName } -func (a *Adapter) Meta() email.ProviderMeta { +func (*Adapter) Meta() email.ProviderMeta { return email.ProviderMeta{ Provider: string(ProviderName), DisplayName: "Generic (SMTP/IMAP)", @@ -49,7 +49,7 @@ func (a *Adapter) Meta() email.ProviderMeta { } } -func (a *Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (*Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { for _, key := range []string{"smtp_host", "imap_host", "username", "password"} { if v, _ := raw[key].(string); strings.TrimSpace(v) == "" { return nil, fmt.Errorf("%s is required", key) @@ -75,7 +75,7 @@ func (a *Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { // ---- Sender ---- -func (a *Adapter) Send(ctx context.Context, config map[string]any, msg email.OutboundEmail) (string, error) { +func (*Adapter) Send(ctx context.Context, config map[string]any, msg email.OutboundEmail) (string, error) { host, _ := config["smtp_host"].(string) port := intVal(config["smtp_port"], 587) username, _ := config["username"].(string) @@ -223,7 +223,7 @@ func (c *imapConn) connectAndReceive(ctx context.Context) error { if err != nil { return fmt.Errorf("dial imap (%s): %w", c.security, err) } - defer client.Close() + defer func() { _ = client.Close() }() if err := client.Login(c.username, c.password).Wait(); err != nil { return fmt.Errorf("imap login: %w", err) @@ -302,7 +302,7 @@ func (c *imapConn) fetchNewMessages(ctx context.Context, client *imapclient.Clie BodySection: []*imap.FetchItemBodySection{{}}, } fetchCmd := client.Fetch(uidSet, fetchOpts) - defer fetchCmd.Close() + defer func() { _ = fetchCmd.Close() }() isFirstRun := c.lastUID == 0 processed := 0 @@ -341,7 +341,7 @@ func (c *imapConn) fetchNewMessages(ctx context.Context, client *imapclient.Clie c.logger.Info("imap fetch completed", slog.Int("processed", processed), slog.Uint64("last_uid", uint64(c.lastUID))) } -func (c *imapConn) bufToInbound(buf *imapclient.FetchMessageBuffer) *email.InboundEmail { +func (*imapConn) bufToInbound(buf *imapclient.FetchMessageBuffer) *email.InboundEmail { env := buf.Envelope if env == nil { return nil @@ -373,7 +373,7 @@ func (c *imapConn) bufToInbound(buf *imapclient.FetchMessageBuffer) *email.Inbou // ---- MailboxReader (on-demand IMAP queries) ---- -func (a *Adapter) dialIMAP(config map[string]any) (*imapclient.Client, error) { +func (*Adapter) dialIMAP(config map[string]any) (*imapclient.Client, error) { host, _ := config["imap_host"].(string) port := intVal(config["imap_port"], 993) username, _ := config["username"].(string) @@ -397,22 +397,22 @@ func (a *Adapter) dialIMAP(config map[string]any) (*imapclient.Client, error) { return nil, err } if err := client.Login(username, password).Wait(); err != nil { - client.Close() + _ = client.Close() return nil, err } if _, err := client.Select("INBOX", nil).Wait(); err != nil { - client.Close() + _ = client.Close() return nil, err } return client, nil } -func (a *Adapter) ListMailbox(ctx context.Context, config map[string]any, page, pageSize int) ([]email.InboundEmail, int, error) { +func (a *Adapter) ListMailbox(_ context.Context, config map[string]any, page, pageSize int) ([]email.InboundEmail, int, error) { client, err := a.dialIMAP(config) if err != nil { return nil, 0, fmt.Errorf("imap connect: %w", err) } - defer client.Close() + defer func() { _ = client.Close() }() // Get total message count via STATUS statusData, err := client.Status("INBOX", &imap.StatusOptions{NumMessages: true}).Wait() @@ -438,6 +438,9 @@ func (a *Adapter) ListMailbox(ctx context.Context, config map[string]any, page, } seqSet := imap.SeqSet{} + if start > math.MaxUint32 || end > math.MaxUint32 { + return nil, 0, fmt.Errorf("mail sequence range out of bounds: start=%d end=%d", start, end) + } seqSet.AddRange(uint32(start), uint32(end)) fetchOpts := &imap.FetchOptions{ @@ -445,7 +448,7 @@ func (a *Adapter) ListMailbox(ctx context.Context, config map[string]any, page, UID: true, } fetchCmd := client.Fetch(seqSet, fetchOpts) - defer fetchCmd.Close() + defer func() { _ = fetchCmd.Close() }() var results []email.InboundEmail for { @@ -478,12 +481,12 @@ func (a *Adapter) ListMailbox(ctx context.Context, config map[string]any, page, return results, total, nil } -func (a *Adapter) ReadMailbox(ctx context.Context, config map[string]any, uid uint32) (*email.InboundEmail, error) { +func (a *Adapter) ReadMailbox(_ context.Context, config map[string]any, uid uint32) (*email.InboundEmail, error) { client, err := a.dialIMAP(config) if err != nil { return nil, fmt.Errorf("imap connect: %w", err) } - defer client.Close() + defer func() { _ = client.Close() }() uidSet := imap.UIDSet{} uidSet.AddNum(imap.UID(uid)) @@ -494,7 +497,7 @@ func (a *Adapter) ReadMailbox(ctx context.Context, config map[string]any, uid ui BodySection: []*imap.FetchItemBodySection{{}}, } fetchCmd := client.Fetch(uidSet, fetchOpts) - defer fetchCmd.Close() + defer func() { _ = fetchCmd.Close() }() msgData := fetchCmd.Next() if msgData == nil { diff --git a/internal/email/adapters/mailgun/adapter.go b/internal/email/adapters/mailgun/adapter.go index 0c1990eb..5cc45cf6 100644 --- a/internal/email/adapters/mailgun/adapter.go +++ b/internal/email/adapters/mailgun/adapter.go @@ -5,6 +5,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" + "errors" "fmt" "log/slog" "net/http" @@ -33,9 +34,9 @@ func New(log *slog.Logger) *Adapter { return &Adapter{logger: log.With(slog.String("adapter", "mailgun"))} } -func (a *Adapter) Type() email.ProviderName { return ProviderName } +func (*Adapter) Type() email.ProviderName { return ProviderName } -func (a *Adapter) Meta() email.ProviderMeta { +func (*Adapter) Meta() email.ProviderMeta { return email.ProviderMeta{ Provider: string(ProviderName), DisplayName: "Mailgun", @@ -52,7 +53,7 @@ func (a *Adapter) Meta() email.ProviderMeta { } } -func (a *Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (*Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { for _, key := range []string{"domain", "api_key"} { if v, _ := raw[key].(string); strings.TrimSpace(v) == "" { return nil, fmt.Errorf("%s is required", key) @@ -64,7 +65,7 @@ func (a *Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { } if mode == InboundModeWebhook { if v, _ := raw["webhook_signing_key"].(string); strings.TrimSpace(v) == "" { - return nil, fmt.Errorf("webhook_signing_key is required for webhook mode") + return nil, errors.New("webhook_signing_key is required for webhook mode") } } if _, ok := raw["region"]; !ok { @@ -81,14 +82,14 @@ func newClient(config map[string]any) *mg.Client { client := mg.NewMailgun(apiKey) region, _ := config["region"].(string) if region == "eu" { - client.SetAPIBase(mg.APIBaseEU) + _ = client.SetAPIBase(mg.APIBaseEU) } return client } // ---- Sender ---- -func (a *Adapter) Send(ctx context.Context, config map[string]any, msg email.OutboundEmail) (string, error) { +func (*Adapter) Send(ctx context.Context, config map[string]any, msg email.OutboundEmail) (string, error) { client := newClient(config) domain, _ := config["domain"].(string) @@ -137,7 +138,7 @@ func (a *Adapter) StartReceiving(ctx context.Context, config map[string]any, han // ---- WebhookReceiver ---- -func (a *Adapter) HandleWebhook(_ context.Context, config map[string]any, r *http.Request) (*email.InboundEmail, error) { +func (*Adapter) HandleWebhook(_ context.Context, config map[string]any, r *http.Request) (*email.InboundEmail, error) { signingKey, _ := config["webhook_signing_key"].(string) if err := r.ParseMultipartForm(10 << 20); err != nil { @@ -154,7 +155,7 @@ func (a *Adapter) HandleWebhook(_ context.Context, config map[string]any, r *htt mac.Write([]byte(timestamp + token)) expected := hex.EncodeToString(mac.Sum(nil)) if !hmac.Equal([]byte(expected), []byte(signature)) { - return nil, fmt.Errorf("webhook signature verification failed") + return nil, errors.New("webhook signature verification failed") } } @@ -249,7 +250,7 @@ func (c *pollConn) pollEvents(ctx context.Context) { type noopStopper struct{} -func (n *noopStopper) Stop(_ context.Context) error { return nil } +func (*noopStopper) Stop(_ context.Context) error { return nil } func intVal(v any, fallback int) int { switch n := v.(type) { diff --git a/internal/email/manager.go b/internal/email/manager.go index f860b015..731844f8 100644 --- a/internal/email/manager.go +++ b/internal/email/manager.go @@ -2,6 +2,7 @@ package email import ( "context" + "errors" "fmt" "log/slog" "sync" @@ -9,10 +10,10 @@ import ( // Manager manages the lifecycle of all email receiving connections. type Manager struct { - logger *slog.Logger - service *Service - trigger *Trigger - outbox *OutboxService + logger *slog.Logger + service *Service + trigger *Trigger + outbox *OutboxService mu sync.Mutex conns map[string]Stopper // provider_id -> stopper @@ -57,7 +58,7 @@ func (m *Manager) startProvider(ctx context.Context, p ProviderResponse) error { defer m.mu.Unlock() if m.stopped { - return fmt.Errorf("manager is stopped") + return errors.New("manager is stopped") } if _, exists := m.conns[p.ID]; exists { return nil @@ -86,7 +87,7 @@ func (m *Manager) startProvider(ctx context.Context, p ProviderResponse) error { // RefreshProvider restarts receiving for a specific provider. func (m *Manager) RefreshProvider(ctx context.Context, providerID string) error { - m.stopProvider(providerID) + m.stopProvider(ctx, providerID) p, err := m.service.GetProvider(ctx, providerID) if err != nil { @@ -104,7 +105,7 @@ func (m *Manager) RefreshProvider(ctx context.Context, providerID string) error return m.startProvider(ctx, p) } -func (m *Manager) stopProvider(providerID string) { +func (m *Manager) stopProvider(ctx context.Context, providerID string) { m.mu.Lock() stopper, exists := m.conns[providerID] if exists { @@ -113,14 +114,14 @@ func (m *Manager) stopProvider(providerID string) { m.mu.Unlock() if exists && stopper != nil { - if err := stopper.Stop(context.Background()); err != nil { + if err := stopper.Stop(ctx); err != nil { m.logger.Error("failed to stop provider", slog.String("provider_id", providerID), slog.Any("error", err)) } } } // Stop gracefully shuts down all receiving connections. -func (m *Manager) Stop() { +func (m *Manager) Stop(ctx context.Context) { m.mu.Lock() m.stopped = true conns := make(map[string]Stopper, len(m.conns)) @@ -131,7 +132,7 @@ func (m *Manager) Stop() { m.mu.Unlock() for id, stopper := range conns { - if err := stopper.Stop(context.Background()); err != nil { + if err := stopper.Stop(ctx); err != nil { m.logger.Error("failed to stop provider", slog.String("provider_id", id), slog.Any("error", err)) } } diff --git a/internal/email/outbox.go b/internal/email/outbox.go index 48109e35..01553b03 100644 --- a/internal/email/outbox.go +++ b/internal/email/outbox.go @@ -117,7 +117,7 @@ func (s *OutboxService) ListByBot(ctx context.Context, botID string, limit, offs return items, count, nil } -func (s *OutboxService) toOutboxResponse(row sqlc.EmailOutbox) OutboxItemResponse { +func (*OutboxService) toOutboxResponse(row sqlc.EmailOutbox) OutboxItemResponse { var to []string _ = json.Unmarshal(row.ToAddresses, &to) var attachments []any diff --git a/internal/email/types.go b/internal/email/types.go index 9246d462..5d997971 100644 --- a/internal/email/types.go +++ b/internal/email/types.go @@ -4,7 +4,6 @@ import "time" type ProviderName string - // FieldSchema describes a single configuration field for dynamic form generation. type FieldSchema struct { Key string `json:"key"` diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index d3deb6f8..c55e7dc8 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -22,11 +22,11 @@ type AuthHandler struct { type LoginRequest struct { Username string `json:"username"` - Password string `json:"password"` + Password string `json:"password"` //nolint:gosec // intentional: JSON request field carrying a user-supplied credential } type LoginResponse struct { - AccessToken string `json:"access_token"` + AccessToken string `json:"access_token"` //nolint:gosec // intentional: JWT is the purpose of this response field TokenType string `json:"token_type"` ExpiresAt string `json:"expires_at"` UserID string `json:"user_id"` @@ -58,7 +58,7 @@ func (h *AuthHandler) Register(e *echo.Echo) { // @Failure 400 {object} ErrorResponse // @Failure 401 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /auth/login [post] +// @Router /auth/login [post]. func (h *AuthHandler) Login(c echo.Context) error { if h.accountService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "user service not configured") @@ -106,7 +106,7 @@ func (h *AuthHandler) Login(c echo.Context) error { } type RefreshResponse struct { - AccessToken string `json:"access_token"` + AccessToken string `json:"access_token"` //nolint:gosec // intentional: JWT is the purpose of this response field TokenType string `json:"token_type"` ExpiresAt string `json:"expires_at"` } @@ -119,7 +119,7 @@ type RefreshResponse struct { // @Success 200 {object} RefreshResponse // @Failure 401 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /auth/refresh [post] +// @Router /auth/refresh [post]. func (h *AuthHandler) Refresh(c echo.Context) error { if strings.TrimSpace(h.jwtSecret) == "" { return echo.NewHTTPError(http.StatusInternalServerError, "jwt secret not configured") diff --git a/internal/handlers/bind.go b/internal/handlers/bind.go index 224245be..e97da1a8 100644 --- a/internal/handlers/bind.go +++ b/internal/handlers/bind.go @@ -77,6 +77,6 @@ func (h *BindHandler) Issue(c echo.Context) error { }) } -func (h *BindHandler) requireUserID(c echo.Context) (string, error) { +func (*BindHandler) requireUserID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index 3ff34e80..a3a89d35 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -38,7 +38,7 @@ func (h *ChannelHandler) Register(e *echo.Echo) { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/channels/{platform} [get] +// @Router /users/me/channels/{platform} [get]. func (h *ChannelHandler) GetChannelIdentityConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -67,7 +67,7 @@ func (h *ChannelHandler) GetChannelIdentityConfig(c echo.Context) error { // @Success 200 {object} channel.ChannelIdentityBinding // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/channels/{platform} [put] +// @Router /users/me/channels/{platform} [put]. func (h *ChannelHandler) UpsertChannelIdentityConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -107,7 +107,7 @@ type ChannelMeta struct { // @Tags channel // @Success 200 {array} ChannelMeta // @Failure 500 {object} ErrorResponse -// @Router /channels [get] +// @Router /channels [get]. func (h *ChannelHandler) ListChannels(c echo.Context) error { descs := h.registry.ListDescriptors() items := make([]ChannelMeta, 0, len(descs)) @@ -136,7 +136,7 @@ func (h *ChannelHandler) ListChannels(c echo.Context) error { // @Success 200 {object} ChannelMeta // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /channels/{platform} [get] +// @Router /channels/{platform} [get]. func (h *ChannelHandler) GetChannel(c echo.Context) error { channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { @@ -158,6 +158,6 @@ func (h *ChannelHandler) GetChannel(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -func (h *ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*ChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index e1337a3c..3c580b56 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -3,7 +3,6 @@ package handlers import ( "context" "errors" - "fmt" "log/slog" "net/http" "sort" @@ -34,7 +33,6 @@ type ContainerdHandler struct { containerBackend string logger *slog.Logger toolGateway *mcp.ToolGatewayService - mcpMu sync.Mutex mcpSess map[string]*mcpSession mcpStdioMu sync.Mutex mcpStdioSess map[string]*mcpStdioSession @@ -150,7 +148,7 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { // @Success 200 {object} CreateContainerResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container [post] +// @Router /bots/{bot_id}/container [post]. func (h *ContainerdHandler) CreateContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -295,7 +293,7 @@ func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (s // @Success 200 {object} GetContainerResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container [get] +// @Router /bots/{bot_id}/container [get]. func (h *ContainerdHandler) GetContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -360,7 +358,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error { // @Success 204 // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container [delete] +// @Router /bots/{bot_id}/container [delete]. func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -379,7 +377,7 @@ func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { // @Success 200 {object} object // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/start [post] +// @Router /bots/{bot_id}/container/start [post]. func (h *ContainerdHandler) StartContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -411,7 +409,7 @@ func (h *ContainerdHandler) StartContainer(c echo.Context) error { // @Success 200 {object} object // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/stop [post] +// @Router /bots/{bot_id}/container/stop [post]. func (h *ContainerdHandler) StopContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -451,7 +449,7 @@ func (h *ContainerdHandler) StopContainer(c echo.Context) error { // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 501 {object} ErrorResponse "Snapshots currently not supported on this backend" -// @Router /bots/{bot_id}/container/snapshots [post] +// @Router /bots/{bot_id}/container/snapshots [post]. func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { if h.containerBackend == "apple" { return echo.NewHTTPError(http.StatusNotImplemented, "snapshots currently not supported on Apple Container backend") @@ -490,7 +488,7 @@ func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { // @Param snapshotter query string false "Snapshotter name" // @Success 200 {object} ListSnapshotsResponse // @Failure 501 {object} ErrorResponse "Snapshots currently not supported on this backend" -// @Router /bots/{bot_id}/container/snapshots [get] +// @Router /bots/{bot_id}/container/snapshots [get]. func (h *ContainerdHandler) ListSnapshots(c echo.Context) error { if h.containerBackend == "apple" { return echo.NewHTTPError(http.StatusNotImplemented, "snapshots currently not supported on Apple Container backend") @@ -658,7 +656,7 @@ func (h *ContainerdHandler) requireBotAccess(c echo.Context) (string, error) { return botID, nil } -func (h *ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*ContainerdHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } @@ -689,7 +687,7 @@ func (h *ContainerdHandler) SetupBotContainer(ctx context.Context, botID string) containerID := mcp.ContainerPrefix + botID if h.manager == nil { - return fmt.Errorf("manager not configured") + return errors.New("manager not configured") } if err := h.manager.Start(ctx, botID); err != nil { diff --git a/internal/handlers/email_bindings.go b/internal/handlers/email_bindings.go index 6c6ac73f..95bca1fc 100644 --- a/internal/handlers/email_bindings.go +++ b/internal/handlers/email_bindings.go @@ -42,7 +42,7 @@ func (h *EmailBindingsHandler) Register(e *echo.Echo) { // @Success 201 {object} email.BindingResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/email-bindings [post] +// @Router /bots/{bot_id}/email-bindings [post]. func (h *EmailBindingsHandler) Create(c echo.Context) error { botID := strings.TrimSpace(c.Param("bot_id")) if botID == "" { @@ -74,7 +74,7 @@ func (h *EmailBindingsHandler) Create(c echo.Context) error { // @Param bot_id path string true "Bot ID" // @Success 200 {array} email.BindingResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/email-bindings [get] +// @Router /bots/{bot_id}/email-bindings [get]. func (h *EmailBindingsHandler) List(c echo.Context) error { botID := strings.TrimSpace(c.Param("bot_id")) if botID == "" { @@ -98,7 +98,7 @@ func (h *EmailBindingsHandler) List(c echo.Context) error { // @Success 200 {object} email.BindingResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/email-bindings/{id} [put] +// @Router /bots/{bot_id}/email-bindings/{id} [put]. func (h *EmailBindingsHandler) Update(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -123,7 +123,7 @@ func (h *EmailBindingsHandler) Update(c echo.Context) error { // @Param id path string true "Binding ID" // @Success 204 "No Content" // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/email-bindings/{id} [delete] +// @Router /bots/{bot_id}/email-bindings/{id} [delete]. func (h *EmailBindingsHandler) Delete(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { diff --git a/internal/handlers/email_outbox.go b/internal/handlers/email_outbox.go index 71f58ff5..e1a5c23e 100644 --- a/internal/handlers/email_outbox.go +++ b/internal/handlers/email_outbox.go @@ -38,19 +38,22 @@ func (h *EmailOutboxHandler) Register(e *echo.Echo) { // @Param offset query int false "Offset" default(0) // @Success 200 {object} map[string]any // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/email-outbox [get] +// @Router /bots/{bot_id}/email-outbox [get]. func (h *EmailOutboxHandler) List(c echo.Context) error { botID := strings.TrimSpace(c.Param("bot_id")) if botID == "" { return echo.NewHTTPError(http.StatusBadRequest, "bot_id is required") } - limit, _ := strconv.Atoi(c.QueryParam("limit")) - if limit <= 0 { - limit = 20 + limit, err := parseInt32Query(c.QueryParam("limit"), 20) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + offset, err := parseInt32Query(c.QueryParam("offset"), 0) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - offset, _ := strconv.Atoi(c.QueryParam("offset")) - items, total, err := h.outbox.ListByBot(c.Request().Context(), botID, int32(limit), int32(offset)) + items, total, err := h.outbox.ListByBot(c.Request().Context(), botID, limit, offset) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -60,6 +63,22 @@ func (h *EmailOutboxHandler) List(c echo.Context) error { }) } +func parseInt32Query(raw string, defaultValue int32) (int32, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return defaultValue, nil + } + parsed, err := strconv.ParseInt(raw, 10, 32) + if err != nil { + return 0, echo.NewHTTPError(http.StatusBadRequest, "invalid integer query parameter") + } + value := int32(parsed) + if value < 0 { + return 0, nil + } + return value, nil +} + // Get godoc // @Summary Get outbox email detail // @Tags email-outbox @@ -68,7 +87,7 @@ func (h *EmailOutboxHandler) List(c echo.Context) error { // @Param id path string true "Email ID" // @Success 200 {object} email.OutboxItemResponse // @Failure 404 {object} ErrorResponse -// @Router /bots/{bot_id}/email-outbox/{id} [get] +// @Router /bots/{bot_id}/email-outbox/{id} [get]. func (h *EmailOutboxHandler) Get(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { diff --git a/internal/handlers/email_providers.go b/internal/handlers/email_providers.go index fa1f0a22..7313d6c3 100644 --- a/internal/handlers/email_providers.go +++ b/internal/handlers/email_providers.go @@ -37,7 +37,7 @@ func (h *EmailProvidersHandler) Register(e *echo.Echo) { // @Description List available email provider types and config schemas // @Tags email-providers // @Success 200 {array} email.ProviderMeta -// @Router /email-providers/meta [get] +// @Router /email-providers/meta [get]. func (h *EmailProvidersHandler) ListMeta(c echo.Context) error { return c.JSON(http.StatusOK, h.service.ListMeta(c.Request().Context())) } @@ -51,7 +51,7 @@ func (h *EmailProvidersHandler) ListMeta(c echo.Context) error { // @Success 201 {object} email.ProviderResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /email-providers [post] +// @Router /email-providers [post]. func (h *EmailProvidersHandler) Create(c echo.Context) error { var req email.CreateProviderRequest if err := c.Bind(&req); err != nil { @@ -77,7 +77,7 @@ func (h *EmailProvidersHandler) Create(c echo.Context) error { // @Param provider query string false "Provider type filter" // @Success 200 {array} email.ProviderResponse // @Failure 500 {object} ErrorResponse -// @Router /email-providers [get] +// @Router /email-providers [get]. func (h *EmailProvidersHandler) List(c echo.Context) error { items, err := h.service.ListProviders(c.Request().Context(), c.QueryParam("provider")) if err != nil { @@ -93,7 +93,7 @@ func (h *EmailProvidersHandler) List(c echo.Context) error { // @Param id path string true "Provider ID" // @Success 200 {object} email.ProviderResponse // @Failure 404 {object} ErrorResponse -// @Router /email-providers/{id} [get] +// @Router /email-providers/{id} [get]. func (h *EmailProvidersHandler) Get(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -116,7 +116,7 @@ func (h *EmailProvidersHandler) Get(c echo.Context) error { // @Success 200 {object} email.ProviderResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /email-providers/{id} [put] +// @Router /email-providers/{id} [put]. func (h *EmailProvidersHandler) Update(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -139,7 +139,7 @@ func (h *EmailProvidersHandler) Update(c echo.Context) error { // @Param id path string true "Provider ID" // @Success 204 "No Content" // @Failure 500 {object} ErrorResponse -// @Router /email-providers/{id} [delete] +// @Router /email-providers/{id} [delete]. func (h *EmailProvidersHandler) Delete(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { diff --git a/internal/handlers/email_webhook.go b/internal/handlers/email_webhook.go index 2063a2b0..a8d85b41 100644 --- a/internal/handlers/email_webhook.go +++ b/internal/handlers/email_webhook.go @@ -43,7 +43,7 @@ func (h *EmailWebhookHandler) Register(e *echo.Echo) { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /email/mailgun/webhook/{config_id} [post] +// @Router /email/mailgun/webhook/{config_id} [post]. func (h *EmailWebhookHandler) HandleMailgun(c echo.Context) error { configID := strings.TrimSpace(c.Param("config_id")) if configID == "" { diff --git a/internal/handlers/filemanager.go b/internal/handlers/filemanager.go index 87acf5a3..62c197bf 100644 --- a/internal/handlers/filemanager.go +++ b/internal/handlers/filemanager.go @@ -78,7 +78,7 @@ func resolveContainerPath(rawPath string) (string, error) { cleaned = "/" } if strings.HasPrefix(cleaned, "..") { - return "", fmt.Errorf("invalid path") + return "", errors.New("invalid path") } return cleaned, nil } @@ -129,7 +129,7 @@ func fsHTTPError(err error) *echo.HTTPError { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs [get] +// @Router /bots/{bot_id}/container/fs [get]. func (h *ContainerdHandler) FSStat(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -176,7 +176,7 @@ func (h *ContainerdHandler) FSStat(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/list [get] +// @Router /bots/{bot_id}/container/fs/list [get]. func (h *ContainerdHandler) FSList(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -234,7 +234,7 @@ func (h *ContainerdHandler) FSList(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/read [get] +// @Router /bots/{bot_id}/container/fs/read [get]. func (h *ContainerdHandler) FSRead(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -279,7 +279,7 @@ func (h *ContainerdHandler) FSRead(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/download [get] +// @Router /bots/{bot_id}/container/fs/download [get]. func (h *ContainerdHandler) FSDownload(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -305,7 +305,7 @@ func (h *ContainerdHandler) FSDownload(c echo.Context) error { if err != nil { return fsHTTPError(err) } - defer rc.Close() + defer func() { _ = rc.Close() }() data, err := io.ReadAll(rc) if err != nil { @@ -332,7 +332,7 @@ func (h *ContainerdHandler) FSDownload(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/write [post] +// @Router /bots/{bot_id}/container/fs/write [post]. func (h *ContainerdHandler) FSWrite(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -376,7 +376,7 @@ func (h *ContainerdHandler) FSWrite(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/upload [post] +// @Router /bots/{bot_id}/container/fs/upload [post]. func (h *ContainerdHandler) FSUpload(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -406,7 +406,7 @@ func (h *ContainerdHandler) FSUpload(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - defer src.Close() + defer func() { _ = src.Close() }() written, err := client.WriteRaw(ctx, containerPath, src) if err != nil { @@ -429,7 +429,7 @@ func (h *ContainerdHandler) FSUpload(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/mkdir [post] +// @Router /bots/{bot_id}/container/fs/mkdir [post]. func (h *ContainerdHandler) FSMkdir(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -472,7 +472,7 @@ func (h *ContainerdHandler) FSMkdir(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/delete [post] +// @Router /bots/{bot_id}/container/fs/delete [post]. func (h *ContainerdHandler) FSDelete(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -519,7 +519,7 @@ func (h *ContainerdHandler) FSDelete(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/fs/rename [post] +// @Router /bots/{bot_id}/container/fs/rename [post]. func (h *ContainerdHandler) FSRename(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { diff --git a/internal/handlers/heartbeat.go b/internal/handlers/heartbeat.go index 0557676c..89a2dd83 100644 --- a/internal/handlers/heartbeat.go +++ b/internal/handlers/heartbeat.go @@ -47,7 +47,7 @@ func (h *HeartbeatHandler) Register(e *echo.Echo) { // @Success 200 {object} heartbeat.ListLogsResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/heartbeat/logs [get] +// @Router /bots/{bot_id}/heartbeat/logs [get]. func (h *HeartbeatHandler) ListLogs(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -91,7 +91,7 @@ func (h *HeartbeatHandler) ListLogs(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/heartbeat/logs [delete] +// @Router /bots/{bot_id}/heartbeat/logs [delete]. func (h *HeartbeatHandler) DeleteLogs(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -110,7 +110,7 @@ func (h *HeartbeatHandler) DeleteLogs(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *HeartbeatHandler) requireUserID(c echo.Context) (string, error) { +func (*HeartbeatHandler) requireUserID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/inbox.go b/internal/handlers/inbox.go index 74029478..95970f16 100644 --- a/internal/handlers/inbox.go +++ b/internal/handlers/inbox.go @@ -52,7 +52,7 @@ func (h *InboxHandler) Register(e *echo.Echo) { // @Success 200 {array} inbox.Item // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/inbox [get] +// @Router /bots/{bot_id}/inbox [get]. func (h *InboxHandler) List(c echo.Context) error { channelIdentityID, err := RequireChannelIdentityID(c) if err != nil { @@ -91,7 +91,7 @@ func (h *InboxHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/inbox/{id} [get] +// @Router /bots/{bot_id}/inbox/{id} [get]. func (h *InboxHandler) GetByID(c echo.Context) error { channelIdentityID, err := RequireChannelIdentityID(c) if err != nil { @@ -124,7 +124,7 @@ func (h *InboxHandler) GetByID(c echo.Context) error { // @Success 201 {object} inbox.Item // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/inbox [post] +// @Router /bots/{bot_id}/inbox [post]. func (h *InboxHandler) Create(c echo.Context) error { channelIdentityID, err := RequireChannelIdentityID(c) if err != nil { @@ -161,7 +161,7 @@ func (h *InboxHandler) Create(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/inbox/{id} [delete] +// @Router /bots/{bot_id}/inbox/{id} [delete]. func (h *InboxHandler) Delete(c echo.Context) error { channelIdentityID, err := RequireChannelIdentityID(c) if err != nil { @@ -197,7 +197,7 @@ type markReadRequest struct { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/inbox/mark-read [post] +// @Router /bots/{bot_id}/inbox/mark-read [post]. func (h *InboxHandler) MarkRead(c echo.Context) error { channelIdentityID, err := RequireChannelIdentityID(c) if err != nil { @@ -231,7 +231,7 @@ func (h *InboxHandler) MarkRead(c echo.Context) error { // @Success 200 {object} inbox.CountResult // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/inbox/count [get] +// @Router /bots/{bot_id}/inbox/count [get]. func (h *InboxHandler) Count(c echo.Context) error { channelIdentityID, err := RequireChannelIdentityID(c) if err != nil { diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index 89eb35d0..06fa1d9f 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -61,7 +61,7 @@ func (h *LocalChannelHandler) Register(e *echo.Echo) { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/web/stream [get] -// @Router /bots/{bot_id}/cli/stream [get] +// @Router /bots/{bot_id}/cli/stream [get]. func (h *LocalChannelHandler) StreamMessages(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -107,10 +107,12 @@ func (h *LocalChannelHandler) StreamMessages(c echo.Context) error { if err != nil { continue } - if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))); err != nil { + if _, err := fmt.Fprintf(writer, "data: %s\n\n", string(data)); err != nil { return nil // client disconnected } - writer.Flush() + if err := writer.Flush(); err != nil { + return nil + } flusher.Flush() } } @@ -138,7 +140,7 @@ type LocalChannelMessageRequest struct { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/web/messages [post] -// @Router /bots/{bot_id}/cli/messages [post] +// @Router /bots/{bot_id}/cli/messages [post]. func (h *LocalChannelHandler) PostMessage(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -208,7 +210,7 @@ func (h *LocalChannelHandler) ensureBotParticipant(ctx context.Context, botID, c return nil } -func (h *LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*LocalChannelHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go index c8877cf9..1fac4a30 100644 --- a/internal/handlers/mcp.go +++ b/internal/handlers/mcp.go @@ -58,7 +58,7 @@ func (h *MCPHandler) Register(e *echo.Echo) { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp [get] +// @Router /bots/{bot_id}/mcp [get]. func (h *MCPHandler) List(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -88,7 +88,7 @@ func (h *MCPHandler) List(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp [post] +// @Router /bots/{bot_id}/mcp [post]. func (h *MCPHandler) Create(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -122,7 +122,7 @@ func (h *MCPHandler) Create(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id} [get] +// @Router /bots/{bot_id}/mcp/{id} [get]. func (h *MCPHandler) Get(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -160,7 +160,7 @@ func (h *MCPHandler) Get(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id} [put] +// @Router /bots/{bot_id}/mcp/{id} [put]. func (h *MCPHandler) Update(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -201,7 +201,7 @@ func (h *MCPHandler) Update(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id} [delete] +// @Router /bots/{bot_id}/mcp/{id} [delete]. func (h *MCPHandler) Delete(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -226,10 +226,10 @@ func (h *MCPHandler) Delete(c echo.Context) error { // ProbeResponse is the response for a probe operation. type ProbeResponse struct { - Status string `json:"status"` + Status string `json:"status"` Tools []mcp.ToolDescriptor `json:"tools"` - Error string `json:"error,omitempty"` - AuthRequired bool `json:"auth_required,omitempty"` + Error string `json:"error,omitempty"` + AuthRequired bool `json:"auth_required,omitempty"` } // Probe godoc @@ -242,7 +242,7 @@ type ProbeResponse struct { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id}/probe [post] +// @Router /bots/{bot_id}/mcp/{id}/probe [post]. func (h *MCPHandler) Probe(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -313,7 +313,7 @@ func (h *MCPHandler) Probe(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/import [put] +// @Router /bots/{bot_id}/mcp/import [put]. func (h *MCPHandler) Import(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -351,7 +351,7 @@ type BatchDeleteRequest struct { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-ops/batch-delete [post] +// @Router /bots/{bot_id}/mcp-ops/batch-delete [post]. func (h *MCPHandler) BatchDelete(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -385,7 +385,7 @@ func (h *MCPHandler) BatchDelete(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/export [get] +// @Router /bots/{bot_id}/mcp/export [get]. func (h *MCPHandler) Export(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -405,7 +405,7 @@ func (h *MCPHandler) Export(c echo.Context) error { return c.JSON(http.StatusOK, resp) } -func (h *MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*MCPHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/mcp_federation_gateway.go b/internal/handlers/mcp_federation_gateway.go index 0955c903..0db5d2ac 100644 --- a/internal/handlers/mcp_federation_gateway.go +++ b/internal/handlers/mcp_federation_gateway.go @@ -3,6 +3,7 @@ package handlers import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -10,8 +11,9 @@ import ( "strings" "time" - mcpgw "github.com/memohai/memoh/internal/mcp" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + mcpgw "github.com/memohai/memoh/internal/mcp" ) type MCPFederationGateway struct { @@ -100,7 +102,7 @@ func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connec func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { url := strings.TrimSpace(anyToString(connection.Config["url"])) if url == "" { - return nil, fmt.Errorf("http mcp url is required") + return nil, errors.New("http mcp url is required") } client := sdkmcp.NewClient(&sdkmcp.Implementation{ Name: "memoh-federation-client", @@ -108,7 +110,7 @@ func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, con }, nil) transport := &sdkmcp.StreamableClientTransport{ Endpoint: url, - HTTPClient: g.connectionHTTPClient(connection), + HTTPClient: g.connectionHTTPClient(ctx, connection), MaxRetries: -1, } return client.Connect(ctx, transport, nil) @@ -117,7 +119,7 @@ func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, con func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { endpoints := resolveSSEEndpointCandidates(connection.Config) if len(endpoints) == 0 { - return nil, fmt.Errorf("sse mcp url is required") + return nil, errors.New("sse mcp url is required") } var lastErr error for _, endpoint := range endpoints { @@ -127,7 +129,7 @@ func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection }, nil) transport := &sdkmcp.SSEClientTransport{ Endpoint: endpoint, - HTTPClient: g.connectionHTTPClient(connection), + HTTPClient: g.connectionHTTPClient(ctx, connection), } session, err := client.Connect(ctx, transport, nil) if err == nil { @@ -136,7 +138,7 @@ func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection lastErr = err } if lastErr == nil { - lastErr = fmt.Errorf("no sse endpoint candidate available") + lastErr = errors.New("no sse endpoint candidate available") } return nil, fmt.Errorf("connect sse mcp failed: %w", lastErr) } @@ -192,7 +194,7 @@ func resolveSSEEndpointCandidates(config map[string]any) []string { return out } -func (g *MCPFederationGateway) connectionHTTPClient(connection mcpgw.Connection) *http.Client { +func (g *MCPFederationGateway) connectionHTTPClient(ctx context.Context, connection mcpgw.Connection) *http.Client { base := g.client if base == nil { base = &http.Client{Timeout: 30 * time.Second} @@ -200,7 +202,7 @@ func (g *MCPFederationGateway) connectionHTTPClient(connection mcpgw.Connection) headers := normalizeHeaderMap(connection.Config["headers"]) if strings.TrimSpace(connection.AuthType) == "oauth" && g.oauthService != nil { - token, err := g.oauthService.GetValidToken(context.Background(), connection.ID) + token, err := g.oauthService.GetValidToken(ctx, connection.ID) if err != nil { g.logger.Warn("failed to get OAuth token for connection", slog.String("connection_id", connection.ID), @@ -273,7 +275,7 @@ func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botI func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, botID string, connection mcpgw.Connection) (*mcpSession, error) { if g.handler == nil { - return nil, fmt.Errorf("containerd handler not configured") + return nil, errors.New("containerd handler not configured") } containerID, err := g.handler.botContainerID(ctx, botID) if err != nil { @@ -285,7 +287,7 @@ func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, command := strings.TrimSpace(anyToString(connection.Config["command"])) if command == "" { - return nil, fmt.Errorf("stdio mcp command is required") + return nil, errors.New("stdio mcp command is required") } request := MCPStdioRequest{ Name: strings.TrimSpace(connection.Name), @@ -303,11 +305,11 @@ func parseGatewayToolsListPayload(payload map[string]any) ([]mcpgw.ToolDescripto } result, ok := payload["result"].(map[string]any) if !ok { - return nil, fmt.Errorf("invalid tools/list result") + return nil, errors.New("invalid tools/list result") } rawTools, ok := result["tools"].([]any) if !ok { - return nil, fmt.Errorf("invalid tools/list tools field") + return nil, errors.New("invalid tools/list tools field") } tools := make([]mcpgw.ToolDescriptor, 0, len(rawTools)) for _, rawTool := range rawTools { diff --git a/internal/handlers/mcp_federation_gateway_test.go b/internal/handlers/mcp_federation_gateway_test.go index ff453626..b37b7b46 100644 --- a/internal/handlers/mcp_federation_gateway_test.go +++ b/internal/handlers/mcp_federation_gateway_test.go @@ -6,8 +6,9 @@ import ( "net/http/httptest" "testing" - mcpgw "github.com/memohai/memoh/internal/mcp" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + mcpgw "github.com/memohai/memoh/internal/mcp" ) type testToolInput struct { @@ -26,7 +27,7 @@ func newTestMCPServer() *sdkmcp.Server { sdkmcp.AddTool(server, &sdkmcp.Tool{ Name: "echo", Description: "Echo query", - }, func(ctx context.Context, request *sdkmcp.CallToolRequest, input testToolInput) (*sdkmcp.CallToolResult, testToolOutput, error) { + }, func(_ context.Context, _ *sdkmcp.CallToolRequest, input testToolInput) (*sdkmcp.CallToolResult, testToolOutput, error) { return nil, testToolOutput{Echo: input.Query}, nil }) return server diff --git a/internal/handlers/mcp_oauth.go b/internal/handlers/mcp_oauth.go index 5dae17cc..e563a3b0 100644 --- a/internal/handlers/mcp_oauth.go +++ b/internal/handlers/mcp_oauth.go @@ -56,7 +56,7 @@ type oauthDiscoverRequest struct { // @Success 200 {object} mcp.DiscoveryResult // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id}/oauth/discover [post] +// @Router /bots/{bot_id}/mcp/{id}/oauth/discover [post]. func (h *MCPOAuthHandler) Discover(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -107,7 +107,7 @@ func (h *MCPOAuthHandler) Discover(c echo.Context) error { type oauthAuthorizeRequest struct { ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` + ClientSecret string `json:"client_secret"` //nolint:gosec // intentional: OAuth client_secret is a required API parameter CallbackURL string `json:"callback_url"` } @@ -120,7 +120,7 @@ type oauthAuthorizeRequest struct { // @Success 200 {object} mcp.AuthorizeResult // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id}/oauth/authorize [post] +// @Router /bots/{bot_id}/mcp/{id}/oauth/authorize [post]. func (h *MCPOAuthHandler) Authorize(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -158,7 +158,7 @@ type oauthExchangeRequest struct { // @Param payload body oauthExchangeRequest true "Authorization code and state" // @Success 200 {object} map[string]bool // @Failure 400 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id}/oauth/exchange [post] +// @Router /bots/{bot_id}/mcp/{id}/oauth/exchange [post]. func (h *MCPOAuthHandler) Exchange(c echo.Context) error { var req oauthExchangeRequest if err := c.Bind(&req); err != nil { @@ -188,7 +188,7 @@ func (h *MCPOAuthHandler) Exchange(c echo.Context) error { // @Success 200 {object} mcp.OAuthStatus // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id}/oauth/status [get] +// @Router /bots/{bot_id}/mcp/{id}/oauth/status [get]. func (h *MCPOAuthHandler) Status(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -218,7 +218,7 @@ func (h *MCPOAuthHandler) Status(c echo.Context) error { // @Param id path string true "MCP connection ID" // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id}/oauth/token [delete] +// @Router /bots/{bot_id}/mcp/{id}/oauth/token [delete]. func (h *MCPOAuthHandler) RevokeToken(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -240,11 +240,10 @@ func (h *MCPOAuthHandler) RevokeToken(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *MCPOAuthHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*MCPOAuthHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } func (h *MCPOAuthHandler) authorizeBotAccess(ctx context.Context, channelIdentityID, botID string) (bots.Bot, error) { return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false}) } - diff --git a/internal/handlers/mcp_session_test.go b/internal/handlers/mcp_session_test.go index aee97c5a..376fe2ca 100644 --- a/internal/handlers/mcp_session_test.go +++ b/internal/handlers/mcp_session_test.go @@ -87,7 +87,7 @@ func (c *fakeMCPConnection) Close() error { return nil } -func (c *fakeMCPConnection) SessionID() string { return "test-session" } +func (*fakeMCPConnection) SessionID() string { return "test-session" } func cloneJSONRPCRequest(req *sdkjsonrpc.Request) *sdkjsonrpc.Request { if req == nil { @@ -295,7 +295,7 @@ func TestMCPSession_ExplicitInitializeNoDoubling(t *testing.T) { // TestMCPSession_PendingCleanupOnContextCancel tests that cancelling a request // context removes it from the pending map. func TestMCPSession_PendingCleanupOnContextCancel(t *testing.T) { - conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + conn := newFakeMCPConnection(func(_ *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { // Never reply — caller should time out. return nil, nil }) @@ -324,7 +324,7 @@ func TestMCPSession_PendingCleanupOnContextCancel(t *testing.T) { // TestMCPSession_PendingCleanupOnClose tests that closing the session drains // all pending channels (callers unblock). func TestMCPSession_PendingCleanupOnClose(t *testing.T) { - conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + conn := newFakeMCPConnection(func(_ *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { return nil, nil // never reply }) sess := newTestMCPSession(conn) diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go index b22e1976..27e4545a 100644 --- a/internal/handlers/mcp_stdio.go +++ b/internal/handlers/mcp_stdio.go @@ -167,7 +167,7 @@ func (s *mcpSession) callRaw(ctx context.Context, req mcptools.JSONRPCRequest) ( } target := sdkIDKey(targetID) if target == "" { - return nil, fmt.Errorf("missing request id") + return nil, errors.New("missing request id") } respCh := make(chan *sdkjsonrpc.Response, 1) @@ -218,7 +218,8 @@ func sdkResponsePayload(resp *sdkjsonrpc.Response) (map[string]any, error) { if resp.Error != nil { code := int64(-32603) message := strings.TrimSpace(resp.Error.Error()) - if wireErr, ok := resp.Error.(*sdkjsonrpc.Error); ok { + wireErr := &sdkjsonrpc.Error{} + if errors.As(resp.Error, &wireErr) { code = wireErr.Code message = strings.TrimSpace(wireErr.Message) } @@ -377,11 +378,11 @@ func (s *mcpSession) invokeCall(ctx context.Context, req *sdkjsonrpc.Request) (* return nil, io.EOF } if req == nil || !req.ID.IsValid() { - return nil, fmt.Errorf("missing request id") + return nil, errors.New("missing request id") } key := sdkIDKey(req.ID) if key == "" { - return nil, fmt.Errorf("invalid request id") + return nil, errors.New("invalid request id") } respCh := make(chan *sdkjsonrpc.Response, 1) @@ -425,7 +426,7 @@ func (s *mcpSession) removePending(key string) { func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { if len(raw) == 0 { - return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + return sdkjsonrpc.ID{}, errors.New("missing request id") } var idValue any if err := json.Unmarshal(raw, &idValue); err != nil { @@ -436,7 +437,7 @@ func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { return sdkjsonrpc.ID{}, err } if !id.IsValid() { - return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + return sdkjsonrpc.ID{}, errors.New("missing request id") } return id, nil } @@ -573,7 +574,7 @@ type mcpStdioSession struct { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-stdio [post] +// @Router /bots/{bot_id}/mcp-stdio [post]. func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -639,7 +640,7 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-stdio/{connection_id} [post] +// @Router /bots/{bot_id}/mcp-stdio/{connection_id} [post]. func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -742,7 +743,7 @@ func (h *ContainerdHandler) startContainerdMCPCommandSession(ctx context.Context for { output, err := execStream.Recv() if err != nil { - if err != io.EOF { + if !errors.Is(err, io.EOF) { h.logger.Debug("exec stream recv done", slog.Any("error", err)) } _ = stdoutW.Close() diff --git a/internal/handlers/mcp_tools.go b/internal/handlers/mcp_tools.go index b4312699..a75d0fef 100644 --- a/internal/handlers/mcp_tools.go +++ b/internal/handlers/mcp_tools.go @@ -3,7 +3,7 @@ package handlers import ( "context" "encoding/json" - "fmt" + "errors" "net/http" "strings" @@ -16,7 +16,7 @@ import ( const ( headerChannelIdentityID = "X-Memoh-Channel-Identity-Id" - headerSessionToken = "X-Memoh-Session-Token" + headerSessionToken = "X-Memoh-Session-Token" //nolint:gosec // G101: this is an HTTP header name, not a hardcoded credential headerCurrentPlatform = "X-Memoh-Current-Platform" headerReplyTarget = "X-Memoh-Reply-Target" ) @@ -35,7 +35,7 @@ func (h *ContainerdHandler) SetToolGatewayService(service *mcpgw.ToolGatewayServ // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/tools [post] +// @Router /bots/{bot_id}/tools [post]. func (h *ContainerdHandler) HandleMCPTools(c echo.Context) error { if h.toolGateway == nil { return echo.NewHTTPError(http.StatusServiceUnavailable, "tool gateway not configured") @@ -141,7 +141,7 @@ func (h *ContainerdHandler) toolGatewayMiddleware(session mcpgw.ToolSessionConte case "tools/call": callReq, ok := req.(*sdkmcp.ServerRequest[*sdkmcp.CallToolParamsRaw]) if !ok || callReq == nil || callReq.Params == nil { - return nil, fmt.Errorf("tools/call params is required") + return nil, errors.New("tools/call params is required") } payload, err := buildToolCallPayloadFromRaw(callReq.Params) if err != nil { @@ -161,11 +161,11 @@ func (h *ContainerdHandler) toolGatewayMiddleware(session mcpgw.ToolSessionConte func buildToolCallPayloadFromRaw(params *sdkmcp.CallToolParamsRaw) (mcpgw.ToolCallPayload, error) { if params == nil { - return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call params is required") + return mcpgw.ToolCallPayload{}, errors.New("tools/call params is required") } name := strings.TrimSpace(params.Name) if name == "" { - return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call name is required") + return mcpgw.ToolCallPayload{}, errors.New("tools/call name is required") } arguments := map[string]any{} if len(params.Arguments) > 0 { @@ -223,7 +223,7 @@ func convertGatewayCallResultToSDK(result map[string]any) (*sdkmcp.CallToolResul return &out, nil } -func (h *ContainerdHandler) buildToolSessionContext(c echo.Context, botID string) mcpgw.ToolSessionContext { +func (*ContainerdHandler) buildToolSessionContext(c echo.Context, botID string) mcpgw.ToolSessionContext { channelIdentityID := strings.TrimSpace(c.Request().Header.Get(headerChannelIdentityID)) if channelIdentityID == "" { if ctxIdentityID, err := auth.UserIDFromContext(c); err == nil { diff --git a/internal/handlers/mcp_tools_test.go b/internal/handlers/mcp_tools_test.go index f9ea36b8..d3c2eb0d 100644 --- a/internal/handlers/mcp_tools_test.go +++ b/internal/handlers/mcp_tools_test.go @@ -3,6 +3,7 @@ package handlers import ( "context" "encoding/json" + "errors" "log/slog" "net/http" "net/http/httptest" @@ -55,7 +56,8 @@ func TestHandleMCPToolsWithoutGateway(t *testing.T) { if err == nil { t.Fatalf("expected service unavailable error") } - httpErr, ok := err.(*echo.HTTPError) + httpErr := &echo.HTTPError{} + ok := errors.As(err, &httpErr) if !ok { t.Fatalf("expected echo HTTP error, got %T", err) } @@ -68,7 +70,7 @@ type mcpToolsTestExecutor struct { lastSession mcpgw.ToolSessionContext } -func (e *mcpToolsTestExecutor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (e *mcpToolsTestExecutor) ListTools(_ context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { e.lastSession = session return []mcpgw.ToolDescriptor{ { @@ -84,7 +86,7 @@ func (e *mcpToolsTestExecutor) ListTools(ctx context.Context, session mcpgw.Tool }, nil } -func (e *mcpToolsTestExecutor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { +func (e *mcpToolsTestExecutor) CallTool(_ context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { e.lastSession = session if strings.TrimSpace(toolName) != "echo_tool" { return nil, mcpgw.ErrToolNotFound diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index aff87253..9cdf2d8e 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -68,8 +68,10 @@ type namespaceScope struct { ScopeID string } -const sharedMemoryNamespace = "bot" -const defaultBuiltinProviderID = "__builtin_default__" +const ( + sharedMemoryNamespace = "bot" + defaultBuiltinProviderID = "__builtin_default__" +) // NewMemoryHandler creates a MemoryHandler. func NewMemoryHandler(log *slog.Logger, botService *bots.Service, accountService *accounts.Service) *MemoryHandler { @@ -159,7 +161,7 @@ func (h *MemoryHandler) checkService(ctx context.Context, botID string) (memprov // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory [post] +// @Router /bots/{bot_id}/memory [post]. func (h *MemoryHandler) ChatAdd(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -218,7 +220,7 @@ func (h *MemoryHandler) ChatAdd(c echo.Context) error { // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/search [post] +// @Router /bots/{bot_id}/memory/search [post]. func (h *MemoryHandler) ChatSearch(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -275,7 +277,7 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory [get] +// @Router /bots/{bot_id}/memory [get]. func (h *MemoryHandler) ChatGetAll(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -323,7 +325,7 @@ func (h *MemoryHandler) ChatGetAll(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory [delete] +// @Router /bots/{bot_id}/memory [delete]. func (h *MemoryHandler) ChatDelete(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -372,7 +374,7 @@ func (h *MemoryHandler) ChatDelete(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/{id} [delete] +// @Router /bots/{bot_id}/memory/{id} [delete]. func (h *MemoryHandler) ChatDeleteOne(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -414,7 +416,7 @@ func (h *MemoryHandler) ChatDeleteOne(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/compact [post] +// @Router /bots/{bot_id}/memory/compact [post]. func (h *MemoryHandler) ChatCompact(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -466,7 +468,7 @@ func (h *MemoryHandler) ChatCompact(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/usage [get] +// @Router /bots/{bot_id}/memory/usage [get]. func (h *MemoryHandler) ChatUsage(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -511,7 +513,7 @@ func (h *MemoryHandler) ChatUsage(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/rebuild [post] +// @Router /bots/{bot_id}/memory/rebuild [post]. func (h *MemoryHandler) ChatRebuild(c echo.Context) error { if h.memoryStore == nil { return echo.NewHTTPError(http.StatusServiceUnavailable, "memory filesystem not configured") @@ -539,7 +541,7 @@ func (h *MemoryHandler) ChatRebuild(c echo.Context) error { // --- helpers --- // resolveEnabledScopes returns bot-shared namespace scope. -func (h *MemoryHandler) resolveEnabledScopes(botID string) ([]namespaceScope, error) { +func (*MemoryHandler) resolveEnabledScopes(botID string) ([]namespaceScope, error) { botID = strings.TrimSpace(botID) if botID == "" { return nil, echo.NewHTTPError(http.StatusBadRequest, "bot id is empty") @@ -551,7 +553,7 @@ func (h *MemoryHandler) resolveEnabledScopes(botID string) ([]namespaceScope, er } // resolveWriteScope returns (scopeID, botID) for shared bot memory. -func (h *MemoryHandler) resolveWriteScope(botID string) (string, string, error) { +func (*MemoryHandler) resolveWriteScope(botID string) (string, string, error) { botID = strings.TrimSpace(botID) if botID == "" { return "", "", echo.NewHTTPError(http.StatusInternalServerError, "bot id is empty") @@ -568,7 +570,7 @@ func normalizeSharedMemoryNamespace(raw string) (string, error) { } } -func (h *MemoryHandler) resolveBotID(c echo.Context) (string, error) { +func (*MemoryHandler) resolveBotID(c echo.Context) (string, error) { botID := strings.TrimSpace(c.Param("bot_id")) if botID == "" { return "", echo.NewHTTPError(http.StatusBadRequest, "bot_id is required") @@ -605,7 +607,7 @@ func deduplicateMemoryItems(items []memprovider.MemoryItem) []memprovider.Memory return result } -func (h *MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*MemoryHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/memory_providers.go b/internal/handlers/memory_providers.go index ba6146d1..76912325 100644 --- a/internal/handlers/memory_providers.go +++ b/internal/handlers/memory_providers.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/labstack/echo/v4" + memprovider "github.com/memohai/memoh/internal/memory/provider" ) @@ -36,7 +37,7 @@ func (h *MemoryProvidersHandler) Register(e *echo.Echo) { // @Description List available memory provider types and config schemas // @Tags memory-providers // @Success 200 {array} provider.ProviderMeta -// @Router /memory-providers/meta [get] +// @Router /memory-providers/meta [get]. func (h *MemoryProvidersHandler) ListMeta(c echo.Context) error { return c.JSON(http.StatusOK, h.service.ListMeta(c.Request().Context())) } @@ -51,7 +52,7 @@ func (h *MemoryProvidersHandler) ListMeta(c echo.Context) error { // @Success 201 {object} provider.ProviderGetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /memory-providers [post] +// @Router /memory-providers [post]. func (h *MemoryProvidersHandler) Create(c echo.Context) error { var req memprovider.ProviderCreateRequest if err := c.Bind(&req); err != nil { @@ -77,7 +78,7 @@ func (h *MemoryProvidersHandler) Create(c echo.Context) error { // @Produce json // @Success 200 {array} provider.ProviderGetResponse // @Failure 500 {object} ErrorResponse -// @Router /memory-providers [get] +// @Router /memory-providers [get]. func (h *MemoryProvidersHandler) List(c echo.Context) error { items, err := h.service.List(c.Request().Context()) if err != nil { @@ -95,7 +96,7 @@ func (h *MemoryProvidersHandler) List(c echo.Context) error { // @Success 200 {object} provider.ProviderGetResponse // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /memory-providers/{id} [get] +// @Router /memory-providers/{id} [get]. func (h *MemoryProvidersHandler) Get(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -119,7 +120,7 @@ func (h *MemoryProvidersHandler) Get(c echo.Context) error { // @Success 200 {object} provider.ProviderGetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /memory-providers/{id} [put] +// @Router /memory-providers/{id} [put]. func (h *MemoryProvidersHandler) Update(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -144,7 +145,7 @@ func (h *MemoryProvidersHandler) Update(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /memory-providers/{id} [delete] +// @Router /memory-providers/{id} [delete]. func (h *MemoryProvidersHandler) Delete(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { diff --git a/internal/handlers/message.go b/internal/handlers/message.go index 0cae5828..6fa31dda 100644 --- a/internal/handlers/message.go +++ b/internal/handlers/message.go @@ -68,7 +68,7 @@ func (h *MessageHandler) Register(e *echo.Echo) { // --- Messages --- func writeSSEData(writer *bufio.Writer, flusher http.Flusher, payload string) error { - if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", payload)); err != nil { + if _, err := fmt.Fprintf(writer, "data: %s\n\n", payload); err != nil { return err } if err := writer.Flush(); err != nil { @@ -101,7 +101,7 @@ func parseSinceParam(raw string) (time.Time, bool, error) { if epochMillis, err := strconv.ParseInt(trimmed, 10, 64); err == nil { return time.UnixMilli(epochMillis).UTC(), true, nil } - return time.Time{}, false, fmt.Errorf("invalid since parameter") + return time.Time{}, false, errors.New("invalid since parameter") } // ListMessages godoc @@ -116,7 +116,7 @@ func parseSinceParam(raw string) (time.Time, bool, error) { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/messages [get] +// @Router /bots/{bot_id}/messages [get]. func (h *MessageHandler) ListMessages(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -325,7 +325,7 @@ func (h *MessageHandler) StreamMessageEvents(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/messages [delete] +// @Router /bots/{bot_id}/messages [delete]. func (h *MessageHandler) DeleteMessages(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -349,7 +349,7 @@ func (h *MessageHandler) DeleteMessages(c echo.Context) error { // --- helpers --- -func (h *MessageHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*MessageHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } @@ -361,30 +361,6 @@ func (h *MessageHandler) authorizeBotManage(ctx context.Context, channelIdentity return AuthorizeBotAccess(ctx, h.botService, h.accountService, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false}) } -func (h *MessageHandler) requireParticipant(ctx context.Context, conversationID, channelIdentityID string) error { - if h.conversationService == nil { - return echo.NewHTTPError(http.StatusInternalServerError, "conversation service not configured") - } - // Admin bypass. - if h.accountService != nil { - isAdmin, err := h.accountService.IsAdmin(ctx, channelIdentityID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if isAdmin { - return nil - } - } - ok, err := h.conversationService.IsParticipant(ctx, conversationID, channelIdentityID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - if !ok { - return echo.NewHTTPError(http.StatusForbidden, "not a participant") - } - return nil -} - func (h *MessageHandler) requireReadable(ctx context.Context, conversationID, channelIdentityID string) error { if h.conversationService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "conversation service not configured") @@ -439,7 +415,7 @@ func (h *MessageHandler) ServeMedia(c echo.Context) error { } return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } - defer reader.Close() + defer func() { _ = reader.Close() }() contentType := asset.Mime if contentType == "" { contentType = "application/octet-stream" diff --git a/internal/handlers/message_test.go b/internal/handlers/message_test.go index 592d8df6..513104e2 100644 --- a/internal/handlers/message_test.go +++ b/internal/handlers/message_test.go @@ -11,7 +11,7 @@ import ( type testFlusher struct{} -func (f *testFlusher) Flush() {} +func (*testFlusher) Flush() {} func TestParseSinceParam(t *testing.T) { t.Parallel() diff --git a/internal/handlers/models.go b/internal/handlers/models.go index 7201bd87..cd944de9 100644 --- a/internal/handlers/models.go +++ b/internal/handlers/models.go @@ -47,7 +47,7 @@ func (h *ModelsHandler) Register(e *echo.Echo) { // @Success 201 {object} models.AddResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models [post] +// @Router /models [post]. func (h *ModelsHandler) Create(c echo.Context) error { var req models.AddRequest if err := c.Bind(&req); err != nil { @@ -73,7 +73,7 @@ func (h *ModelsHandler) Create(c echo.Context) error { // @Success 200 {array} models.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models [get] +// @Router /models [get]. func (h *ModelsHandler) List(c echo.Context) error { modelType := c.QueryParam("type") clientType := c.QueryParam("client_type") @@ -81,11 +81,12 @@ func (h *ModelsHandler) List(c echo.Context) error { var resp []models.GetResponse var err error - if modelType != "" { + switch { + case modelType != "": resp, err = h.service.ListByType(c.Request().Context(), models.ModelType(modelType)) - } else if clientType != "" { + case clientType != "": resp, err = h.service.ListByClientType(c.Request().Context(), models.ClientType(clientType)) - } else { + default: resp, err = h.service.List(c.Request().Context()) } @@ -104,7 +105,7 @@ func (h *ModelsHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/{id} [get] +// @Router /models/{id} [get]. func (h *ModelsHandler) GetByID(c echo.Context) error { id := c.Param("id") if id == "" { @@ -127,7 +128,7 @@ func (h *ModelsHandler) GetByID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/model/{modelId} [get] +// @Router /models/model/{modelId} [get]. func (h *ModelsHandler) GetByModelID(c echo.Context) error { modelID := c.Param("modelId") if modelID == "" { @@ -162,7 +163,7 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/{id} [put] +// @Router /models/{id} [put]. func (h *ModelsHandler) UpdateByID(c echo.Context) error { id := c.Param("id") if id == "" { @@ -194,7 +195,7 @@ func (h *ModelsHandler) UpdateByID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/model/{modelId} [put] +// @Router /models/model/{modelId} [put]. func (h *ModelsHandler) UpdateByModelID(c echo.Context) error { modelID := c.Param("modelId") if modelID == "" { @@ -236,7 +237,7 @@ func (h *ModelsHandler) UpdateByModelID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/{id} [delete] +// @Router /models/{id} [delete]. func (h *ModelsHandler) DeleteByID(c echo.Context) error { id := c.Param("id") if id == "" { @@ -258,7 +259,7 @@ func (h *ModelsHandler) DeleteByID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/model/{modelId} [delete] +// @Router /models/model/{modelId} [delete]. func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { modelID := c.Param("modelId") if modelID == "" { @@ -293,7 +294,7 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/{id}/test [post] +// @Router /models/{id}/test [post]. func (h *ModelsHandler) Test(c echo.Context) error { id := c.Param("id") if id == "" { @@ -319,7 +320,7 @@ func (h *ModelsHandler) Test(c echo.Context) error { // @Success 200 {object} models.CountResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/count [get] +// @Router /models/count [get]. func (h *ModelsHandler) Count(c echo.Context) error { modelType := c.QueryParam("type") diff --git a/internal/handlers/ping.go b/internal/handlers/ping.go index 49f88b27..6d6eab88 100644 --- a/internal/handlers/ping.go +++ b/internal/handlers/ping.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/labstack/echo/v4" + "github.com/memohai/memoh/internal/boot" ) @@ -35,7 +36,7 @@ func (h *PingHandler) Register(e *echo.Echo) { // @Summary Health check with server capabilities // @Tags system // @Success 200 {object} PingResponse -// @Router /ping [get] +// @Router /ping [get]. func (h *PingHandler) Ping(c echo.Context) error { return c.JSON(http.StatusOK, PingResponse{ Status: "ok", @@ -44,6 +45,6 @@ func (h *PingHandler) Ping(c echo.Context) error { }) } -func (h *PingHandler) PingHead(c echo.Context) error { +func (*PingHandler) PingHead(c echo.Context) error { return c.NoContent(http.StatusOK) } diff --git a/internal/handlers/preauth.go b/internal/handlers/preauth.go index eb6e6a90..3e81debe 100644 --- a/internal/handlers/preauth.go +++ b/internal/handlers/preauth.go @@ -63,7 +63,7 @@ func (h *PreauthHandler) Issue(c echo.Context) error { return c.JSON(http.StatusOK, key) } -func (h *PreauthHandler) requireUserID(c echo.Context) (string, error) { +func (*PreauthHandler) requireUserID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/providers.go b/internal/handlers/providers.go index d852999d..9f0ef13a 100644 --- a/internal/handlers/providers.go +++ b/internal/handlers/providers.go @@ -51,7 +51,7 @@ func (h *ProvidersHandler) Register(e *echo.Echo) { // @Success 201 {object} providers.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers [post] +// @Router /providers [post]. func (h *ProvidersHandler) Create(c echo.Context) error { var req providers.CreateRequest if err := c.Bind(&req); err != nil { @@ -82,7 +82,7 @@ func (h *ProvidersHandler) Create(c echo.Context) error { // @Produce json // @Success 200 {array} providers.GetResponse // @Failure 500 {object} ErrorResponse -// @Router /providers [get] +// @Router /providers [get]. func (h *ProvidersHandler) List(c echo.Context) error { resp, err := h.service.List(c.Request().Context()) if err != nil { @@ -103,7 +103,7 @@ func (h *ProvidersHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id} [get] +// @Router /providers/{id} [get]. func (h *ProvidersHandler) Get(c echo.Context) error { id := c.Param("id") if id == "" { @@ -128,7 +128,7 @@ func (h *ProvidersHandler) Get(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id}/models [get] +// @Router /providers/{id}/models [get]. func (h *ProvidersHandler) ListModelsByProvider(c echo.Context) error { if h.modelsService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "models service not configured") @@ -167,7 +167,7 @@ func (h *ProvidersHandler) ListModelsByProvider(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/name/{name} [get] +// @Router /providers/name/{name} [get]. func (h *ProvidersHandler) GetByName(c echo.Context) error { name := c.Param("name") if name == "" { @@ -194,7 +194,7 @@ func (h *ProvidersHandler) GetByName(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id} [put] +// @Router /providers/{id} [put]. func (h *ProvidersHandler) Update(c echo.Context) error { id := c.Param("id") if id == "" { @@ -225,7 +225,7 @@ func (h *ProvidersHandler) Update(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id} [delete] +// @Router /providers/{id} [delete]. func (h *ProvidersHandler) Delete(c echo.Context) error { id := c.Param("id") if id == "" { @@ -247,7 +247,7 @@ func (h *ProvidersHandler) Delete(c echo.Context) error { // @Produce json // @Success 200 {object} providers.CountResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/count [get] +// @Router /providers/count [get]. func (h *ProvidersHandler) Count(c echo.Context) error { count, err := h.service.Count(c.Request().Context()) if err != nil { @@ -268,7 +268,7 @@ func (h *ProvidersHandler) Count(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id}/test [post] +// @Router /providers/{id}/test [post]. func (h *ProvidersHandler) Test(c echo.Context) error { id := c.Param("id") if id == "" { @@ -298,7 +298,7 @@ func (h *ProvidersHandler) Test(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id}/import-models [post] +// @Router /providers/{id}/import-models [post]. func (h *ProvidersHandler) ImportModels(c echo.Context) error { id := c.Param("id") if id == "" { @@ -333,7 +333,6 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error { Type: models.ModelTypeChat, InputModalities: []string{models.ModelInputText}, }) - if err != nil { if errors.Is(err, models.ErrModelIDAlreadyExists) { resp.Skipped++ diff --git a/internal/handlers/schedule.go b/internal/handlers/schedule.go index 00eebcc5..30fa585c 100644 --- a/internal/handlers/schedule.go +++ b/internal/handlers/schedule.go @@ -46,7 +46,7 @@ func (h *ScheduleHandler) Register(e *echo.Echo) { // @Success 201 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule [post] +// @Router /bots/{bot_id}/schedule [post]. func (h *ScheduleHandler) Create(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -77,7 +77,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error { // @Success 200 {object} schedule.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule [get] +// @Router /bots/{bot_id}/schedule [get]. func (h *ScheduleHandler) List(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -106,7 +106,7 @@ func (h *ScheduleHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule/{id} [get] +// @Router /bots/{bot_id}/schedule/{id} [get]. func (h *ScheduleHandler) Get(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -142,7 +142,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error { // @Success 200 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule/{id} [put] +// @Router /bots/{bot_id}/schedule/{id} [put]. func (h *ScheduleHandler) Update(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -185,7 +185,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule/{id} [delete] +// @Router /bots/{bot_id}/schedule/{id} [delete]. func (h *ScheduleHandler) Delete(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -215,7 +215,7 @@ func (h *ScheduleHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *ScheduleHandler) requireUserID(c echo.Context) (string, error) { +func (*ScheduleHandler) requireUserID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/search_providers.go b/internal/handlers/search_providers.go index 9289fa8f..aeb30e43 100644 --- a/internal/handlers/search_providers.go +++ b/internal/handlers/search_providers.go @@ -37,7 +37,7 @@ func (h *SearchProvidersHandler) Register(e *echo.Echo) { // @Description List available search provider types and config schemas // @Tags search-providers // @Success 200 {array} searchproviders.ProviderMeta -// @Router /search-providers/meta [get] +// @Router /search-providers/meta [get]. func (h *SearchProvidersHandler) ListMeta(c echo.Context) error { return c.JSON(http.StatusOK, h.service.ListMeta(c.Request().Context())) } @@ -52,7 +52,7 @@ func (h *SearchProvidersHandler) ListMeta(c echo.Context) error { // @Success 201 {object} searchproviders.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers [post] +// @Router /search-providers [post]. func (h *SearchProvidersHandler) Create(c echo.Context) error { var req searchproviders.CreateRequest if err := c.Bind(&req); err != nil { @@ -80,7 +80,7 @@ func (h *SearchProvidersHandler) Create(c echo.Context) error { // @Param provider query string false "Provider filter (brave)" // @Success 200 {array} searchproviders.GetResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers [get] +// @Router /search-providers [get]. func (h *SearchProvidersHandler) List(c echo.Context) error { items, err := h.service.List(c.Request().Context(), c.QueryParam("provider")) if err != nil { @@ -99,7 +99,7 @@ func (h *SearchProvidersHandler) List(c echo.Context) error { // @Success 200 {object} searchproviders.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /search-providers/{id} [get] +// @Router /search-providers/{id} [get]. func (h *SearchProvidersHandler) Get(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -123,7 +123,7 @@ func (h *SearchProvidersHandler) Get(c echo.Context) error { // @Success 200 {object} searchproviders.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers/{id} [put] +// @Router /search-providers/{id} [put]. func (h *SearchProvidersHandler) Update(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -150,7 +150,7 @@ func (h *SearchProvidersHandler) Update(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers/{id} [delete] +// @Router /search-providers/{id} [delete]. func (h *SearchProvidersHandler) Delete(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index 717f417e..05e8526e 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -48,7 +48,7 @@ func (h *SettingsHandler) Register(e *echo.Echo) { // @Success 200 {object} settings.Settings // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/settings [get] +// @Router /bots/{bot_id}/settings [get]. func (h *SettingsHandler) Get(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -77,7 +77,7 @@ func (h *SettingsHandler) Get(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/settings [put] -// @Router /bots/{bot_id}/settings [post] +// @Router /bots/{bot_id}/settings [post]. func (h *SettingsHandler) Upsert(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -124,7 +124,7 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/settings [delete] +// @Router /bots/{bot_id}/settings [delete]. func (h *SettingsHandler) Delete(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -143,7 +143,7 @@ func (h *SettingsHandler) Delete(c echo.Context) error { return c.NoContent(http.StatusNoContent) } -func (h *SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*SettingsHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/skills.go b/internal/handlers/skills.go index f94ee4c7..a3f9787a 100644 --- a/internal/handlers/skills.go +++ b/internal/handlers/skills.go @@ -48,7 +48,7 @@ type skillsOpResponse struct { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/skills [get] +// @Router /bots/{bot_id}/container/skills [get]. func (h *ContainerdHandler) ListSkills(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -73,7 +73,7 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/skills [post] +// @Router /bots/{bot_id}/container/skills [post]. func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -120,7 +120,7 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/skills [delete] +// @Router /bots/{bot_id}/container/skills [delete]. func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { diff --git a/internal/handlers/subagent.go b/internal/handlers/subagent.go index 9067cef8..ebaf9b55 100644 --- a/internal/handlers/subagent.go +++ b/internal/handlers/subagent.go @@ -51,7 +51,7 @@ func (h *SubagentHandler) Register(e *echo.Echo) { // @Success 201 {object} subagent.Subagent // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents [post] +// @Router /bots/{bot_id}/subagents [post]. func (h *SubagentHandler) Create(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -82,7 +82,7 @@ func (h *SubagentHandler) Create(c echo.Context) error { // @Success 200 {object} subagent.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents [get] +// @Router /bots/{bot_id}/subagents [get]. func (h *SubagentHandler) List(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -111,7 +111,7 @@ func (h *SubagentHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id} [get] +// @Router /bots/{bot_id}/subagents/{id} [get]. func (h *SubagentHandler) Get(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -148,7 +148,7 @@ func (h *SubagentHandler) Get(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id} [put] +// @Router /bots/{bot_id}/subagents/{id} [put]. func (h *SubagentHandler) Update(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -192,7 +192,7 @@ func (h *SubagentHandler) Update(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id} [delete] +// @Router /bots/{bot_id}/subagents/{id} [delete]. func (h *SubagentHandler) Delete(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -231,7 +231,7 @@ func (h *SubagentHandler) Delete(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/context [get] +// @Router /bots/{bot_id}/subagents/{id}/context [get]. func (h *SubagentHandler) GetContext(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -268,7 +268,7 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/context [put] +// @Router /bots/{bot_id}/subagents/{id}/context [put]. func (h *SubagentHandler) UpdateContext(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -312,7 +312,7 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/skills [get] +// @Router /bots/{bot_id}/subagents/{id}/skills [get]. func (h *SubagentHandler) GetSkills(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -349,7 +349,7 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/skills [put] +// @Router /bots/{bot_id}/subagents/{id}/skills [put]. func (h *SubagentHandler) UpdateSkills(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -394,7 +394,7 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/skills [post] +// @Router /bots/{bot_id}/subagents/{id}/skills [post]. func (h *SubagentHandler) AddSkills(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -429,7 +429,7 @@ func (h *SubagentHandler) AddSkills(c echo.Context) error { return c.JSON(http.StatusOK, subagent.SkillsResponse{Skills: updated.Skills}) } -func (h *SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*SubagentHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/handlers/swagger.go b/internal/handlers/swagger.go index 4d4dec89..95261c61 100644 --- a/internal/handlers/swagger.go +++ b/internal/handlers/swagger.go @@ -34,7 +34,7 @@ func (h *SwaggerHandler) Register(e *echo.Echo) { e.GET("api/docs/", h.UI) } -func (h *SwaggerHandler) Spec(c echo.Context) error { +func (*SwaggerHandler) Spec(c echo.Context) error { swaggerOnce.Do(func() { swaggerSpec, swaggerErr = os.ReadFile("spec/swagger.json") }) @@ -44,7 +44,7 @@ func (h *SwaggerHandler) Spec(c echo.Context) error { return c.Blob(http.StatusOK, "application/json", swaggerSpec) } -func (h *SwaggerHandler) UI(c echo.Context) error { +func (*SwaggerHandler) UI(c echo.Context) error { return c.HTML(http.StatusOK, swaggerUIHTML) } diff --git a/internal/handlers/token_usage.go b/internal/handlers/token_usage.go index e9858a3b..72098fce 100644 --- a/internal/handlers/token_usage.go +++ b/internal/handlers/token_usage.go @@ -75,7 +75,7 @@ type TokenUsageResponse struct { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/token-usage [get] +// @Router /bots/{bot_id}/token-usage [get]. func (h *TokenUsageHandler) GetTokenUsage(c echo.Context) error { userID, err := RequireChannelIdentityID(c) if err != nil { diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 637980c3..4b33076b 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -93,7 +93,7 @@ func (h *UsersHandler) Register(e *echo.Echo) { // @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me [get] +// @Router /users/me [get]. func (h *UsersHandler) GetMe(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -114,7 +114,7 @@ func (h *UsersHandler) GetMe(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/identities [get] +// @Router /users/me/identities [get]. func (h *UsersHandler) ListMyIdentities(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -141,7 +141,7 @@ func (h *UsersHandler) ListMyIdentities(c echo.Context) error { // @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me [put] +// @Router /users/me [put]. func (h *UsersHandler) UpdateMe(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -166,7 +166,7 @@ func (h *UsersHandler) UpdateMe(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/password [put] +// @Router /users/me/password [put]. func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -193,7 +193,7 @@ func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users [get] +// @Router /users [get]. func (h *UsersHandler) ListUsers(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -226,7 +226,7 @@ func (h *UsersHandler) ListUsers(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/{id} [get] +// @Router /users/{id} [get]. func (h *UsersHandler) GetUser(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -266,7 +266,7 @@ func (h *UsersHandler) GetUser(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/{id} [put] +// @Router /users/{id} [put]. func (h *UsersHandler) UpdateUser(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -312,7 +312,7 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/{id}/password [put] +// @Router /users/{id}/password [put]. func (h *UsersHandler) ResetUserPassword(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -354,7 +354,7 @@ func (h *UsersHandler) ResetUserPassword(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users [post] +// @Router /users [post]. func (h *UsersHandler) CreateUser(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -371,6 +371,7 @@ func (h *UsersHandler) CreateUser(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } + //nolint:staticcheck // Keep backward-compatible behavior: CreateHuman creates backing user when owner id is empty. resp, err := h.service.CreateHuman(c.Request().Context(), "", req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -387,7 +388,7 @@ func (h *UsersHandler) CreateUser(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots [post] +// @Router /bots [post]. func (h *UsersHandler) CreateBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -458,7 +459,7 @@ func (h *UsersHandler) CreateBot(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots [get] +// @Router /bots [get]. func (h *UsersHandler) ListBots(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -496,7 +497,7 @@ func (h *UsersHandler) ListBots(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id} [get] +// @Router /bots/{id} [get]. func (h *UsersHandler) GetBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -523,7 +524,7 @@ func (h *UsersHandler) GetBot(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/checks [get] +// @Router /bots/{id}/checks [get]. func (h *UsersHandler) ListBotChecks(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -557,7 +558,7 @@ func (h *UsersHandler) ListBotChecks(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id} [put] +// @Router /bots/{id} [put]. func (h *UsersHandler) UpdateBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -592,7 +593,7 @@ func (h *UsersHandler) UpdateBot(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/owner [put] +// @Router /bots/{id}/owner [put]. func (h *UsersHandler) TransferBotOwner(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -636,7 +637,7 @@ func (h *UsersHandler) TransferBotOwner(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id} [delete] +// @Router /bots/{id} [delete]. func (h *UsersHandler) DeleteBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -671,7 +672,7 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/members [get] +// @Router /bots/{id}/members [get]. func (h *UsersHandler) ListBotMembers(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -702,7 +703,7 @@ func (h *UsersHandler) ListBotMembers(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/members [put] +// @Router /bots/{id}/members [put]. func (h *UsersHandler) UpsertBotMember(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -741,7 +742,7 @@ func (h *UsersHandler) UpsertBotMember(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/members/{user_id} [delete] +// @Router /bots/{id}/members/{user_id} [delete]. func (h *UsersHandler) DeleteBotMember(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -775,7 +776,7 @@ func (h *UsersHandler) DeleteBotMember(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform} [get] +// @Router /bots/{id}/channel/{platform} [get]. func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -817,7 +818,7 @@ func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform} [put] +// @Router /bots/{id}/channel/{platform} [put]. func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -867,7 +868,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform}/status [patch] +// @Router /bots/{id}/channel/{platform}/status [patch]. func (h *UsersHandler) UpdateBotChannelStatus(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -915,7 +916,7 @@ func (h *UsersHandler) UpdateBotChannelStatus(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform} [delete] +// @Router /bots/{id}/channel/{platform} [delete]. func (h *UsersHandler) DeleteBotChannelConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -953,7 +954,7 @@ func (h *UsersHandler) DeleteBotChannelConfig(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform}/send [post] +// @Router /bots/{id}/channel/{platform}/send [post]. func (h *UsersHandler) SendBotMessage(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -998,7 +999,7 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { // @Failure 401 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform}/send_chat [post] +// @Router /bots/{id}/channel/{platform}/send_chat [post]. func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { chatToken, err := auth.ChatTokenFromContext(c) if err != nil { @@ -1046,6 +1047,6 @@ func (h *UsersHandler) authorizeBotAccess(ctx context.Context, channelIdentityID return AuthorizeBotAccess(ctx, h.botService, h.service, channelIdentityID, botID, bots.AccessPolicy{AllowPublicMember: false}) } -func (h *UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) { +func (*UsersHandler) requireChannelIdentityID(c echo.Context) (string, error) { return RequireChannelIdentityID(c) } diff --git a/internal/healthcheck/adapter_test.go b/internal/healthcheck/adapter_test.go index 489dd66b..d0f3a133 100644 --- a/internal/healthcheck/adapter_test.go +++ b/internal/healthcheck/adapter_test.go @@ -9,7 +9,7 @@ type testChecker struct { items []CheckResult } -func (c *testChecker) ListChecks(ctx context.Context, botID string) []CheckResult { +func (c *testChecker) ListChecks(_ context.Context, _ string) []CheckResult { return c.items } diff --git a/internal/healthcheck/checkers/channel/checker.go b/internal/healthcheck/checkers/channel/checker.go index 3200ce32..61949bf3 100644 --- a/internal/healthcheck/checkers/channel/checker.go +++ b/internal/healthcheck/checkers/channel/checker.go @@ -40,9 +40,6 @@ func NewChecker(log *slog.Logger, observer ConnectionObserver) *Checker { // ListChecks evaluates channel connection statuses for a bot. func (c *Checker) ListChecks(ctx context.Context, botID string) []healthcheck.CheckResult { - if ctx == nil { - ctx = context.Background() - } // Connection observer is context-free; best effort early cancellation guard. if err := ctx.Err(); err != nil { return []healthcheck.CheckResult{} diff --git a/internal/healthcheck/checkers/channel/checker_test.go b/internal/healthcheck/checkers/channel/checker_test.go index ebe5ca3e..524fa262 100644 --- a/internal/healthcheck/checkers/channel/checker_test.go +++ b/internal/healthcheck/checkers/channel/checker_test.go @@ -2,7 +2,6 @@ package channelchecker import ( "context" - "io" "log/slog" "testing" "time" @@ -14,12 +13,12 @@ type fakeConnectionObserver struct { items []channel.ConnectionStatus } -func (f *fakeConnectionObserver) ConnectionStatusesByBot(botID string) []channel.ConnectionStatus { +func (f *fakeConnectionObserver) ConnectionStatusesByBot(_ string) []channel.ConnectionStatus { return f.items } func newTestLogger() *slog.Logger { - return slog.New(slog.NewTextHandler(io.Discard, nil)) + return slog.New(slog.DiscardHandler) } func TestCheckerListChecks(t *testing.T) { diff --git a/internal/healthcheck/checkers/mcp/checker.go b/internal/healthcheck/checkers/mcp/checker.go index 1de61551..848a3a53 100644 --- a/internal/healthcheck/checkers/mcp/checker.go +++ b/internal/healthcheck/checkers/mcp/checker.go @@ -53,9 +53,6 @@ func NewChecker(log *slog.Logger, connections ConnectionLister, tools ToolLister // ListChecks evaluates all active MCP connections for a bot. func (c *Checker) ListChecks(ctx context.Context, botID string) []healthcheck.CheckResult { - if ctx == nil { - ctx = context.Background() - } botID = strings.TrimSpace(botID) if botID == "" { return []healthcheck.CheckResult{} diff --git a/internal/healthcheck/checkers/mcp/checker_test.go b/internal/healthcheck/checkers/mcp/checker_test.go index 0535fe97..fa758bb7 100644 --- a/internal/healthcheck/checkers/mcp/checker_test.go +++ b/internal/healthcheck/checkers/mcp/checker_test.go @@ -3,7 +3,6 @@ package mcpchecker import ( "context" "errors" - "io" "log/slog" "testing" @@ -15,7 +14,7 @@ type fakeConnectionLister struct { err error } -func (f *fakeConnectionLister) ListActiveByBot(ctx context.Context, botID string) ([]mcp.Connection, error) { +func (f *fakeConnectionLister) ListActiveByBot(_ context.Context, _ string) ([]mcp.Connection, error) { if f.err != nil { return nil, f.err } @@ -27,7 +26,7 @@ type fakeToolLister struct { err error } -func (f *fakeToolLister) ListTools(ctx context.Context, session mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error) { +func (f *fakeToolLister) ListTools(_ context.Context, _ mcp.ToolSessionContext) ([]mcp.ToolDescriptor, error) { if f.err != nil { return nil, f.err } @@ -35,7 +34,7 @@ func (f *fakeToolLister) ListTools(ctx context.Context, session mcp.ToolSessionC } func newTestLogger() *slog.Logger { - return slog.New(slog.NewTextHandler(io.Discard, nil)) + return slog.New(slog.DiscardHandler) } func TestCheckerListChecks(t *testing.T) { diff --git a/internal/healthcheck/checkers/model/checker.go b/internal/healthcheck/checkers/model/checker.go index 3c52dea4..4ba93b16 100644 --- a/internal/healthcheck/checkers/model/checker.go +++ b/internal/healthcheck/checkers/model/checker.go @@ -64,9 +64,6 @@ type modelSlot struct { // ListChecks evaluates model health for a bot. func (c *Checker) ListChecks(ctx context.Context, botID string) []healthcheck.CheckResult { - if ctx == nil { - ctx = context.Background() - } botID = strings.TrimSpace(botID) if botID == "" { return nil diff --git a/internal/healthcheck/checkers/model/lookup.go b/internal/healthcheck/checkers/model/lookup.go index 8522cef0..e207bc1d 100644 --- a/internal/healthcheck/checkers/model/lookup.go +++ b/internal/healthcheck/checkers/model/lookup.go @@ -2,6 +2,7 @@ package modelchecker import ( "context" + "errors" "fmt" "strings" @@ -22,7 +23,7 @@ func NewQueriesLookup(queries *sqlc.Queries) *QueriesLookup { // GetBotModelIDs fetches model IDs configured directly on the bot. func (l *QueriesLookup) GetBotModelIDs(ctx context.Context, botID string) (BotModels, error) { if strings.TrimSpace(botID) == "" { - return BotModels{}, fmt.Errorf("bot id is required") + return BotModels{}, errors.New("bot id is required") } pgID, err := db.ParseUUID(botID) if err != nil { diff --git a/internal/heartbeat/service.go b/internal/heartbeat/service.go index 662477e0..721559e0 100644 --- a/internal/heartbeat/service.go +++ b/internal/heartbeat/service.go @@ -3,6 +3,7 @@ package heartbeat import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -46,7 +47,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, ru func (s *Service) Bootstrap(ctx context.Context) error { if s.queries == nil { - return fmt.Errorf("heartbeat queries not configured") + return errors.New("heartbeat queries not configured") } rows, err := s.queries.ListHeartbeatEnabledBots(ctx) if err != nil { @@ -60,7 +61,7 @@ func (s *Service) Bootstrap(ctx context.Context) error { OwnerUserID: ownerUserID, Interval: int(row.HeartbeatInterval), } - if err := s.scheduleJob(cfg); err != nil { + if err := s.scheduleJob(ctx, cfg); err != nil { s.logger.Error("failed to schedule heartbeat", slog.String("bot_id", botID), slog.Any("error", err)) } } @@ -87,7 +88,7 @@ func (s *Service) Reschedule(ctx context.Context, botID string) error { OwnerUserID: bot.OwnerUserID.String(), Interval: int(bot.HeartbeatInterval), } - return s.scheduleJob(cfg) + return s.scheduleJob(ctx, cfg) } func (s *Service) Stop(botID string) { @@ -186,7 +187,7 @@ func (s *Service) DeleteLogs(ctx context.Context, botID string) error { func (s *Service) generateTriggerToken(userID string) (string, error) { if strings.TrimSpace(s.jwtSecret) == "" { - return "", fmt.Errorf("jwt secret not configured") + return "", errors.New("jwt secret not configured") } signed, _, err := auth.GenerateToken(userID, s.jwtSecret, heartbeatTokenTTL) if err != nil { @@ -195,13 +196,13 @@ func (s *Service) generateTriggerToken(userID string) (string, error) { return "Bearer " + signed, nil } -func (s *Service) scheduleJob(cfg Config) error { +func (s *Service) scheduleJob(ctx context.Context, cfg Config) error { if cfg.Interval <= 0 { cfg.Interval = 30 } spec := fmt.Sprintf("@every %dm", cfg.Interval) job := func() { - s.runHeartbeat(context.Background(), cfg) + s.runHeartbeat(context.WithoutCancel(ctx), cfg) } entryID, err := s.cron.AddFunc(spec, job) if err != nil { diff --git a/internal/identity/user.go b/internal/identity/user.go index 6e5b9d41..d47871d9 100644 --- a/internal/identity/user.go +++ b/internal/identity/user.go @@ -2,6 +2,7 @@ package identity import ( "fmt" + "strings" ctr "github.com/memohai/memoh/internal/containerd" ) @@ -11,8 +12,9 @@ func ValidateChannelIdentityID(channelIdentityID string) error { if channelIdentityID == "" { return fmt.Errorf("%w: channel identity id required", ctr.ErrInvalidArgument) } + const allowedRunes = "-_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" for _, r := range channelIdentityID { - if !(r == '-' || r == '_' || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) { + if !strings.ContainsRune(allowedRunes, r) { return fmt.Errorf("%w: invalid channel identity id", ctr.ErrInvalidArgument) } } diff --git a/internal/inbox/service.go b/internal/inbox/service.go index 9f1d3633..8d4974a4 100644 --- a/internal/inbox/service.go +++ b/internal/inbox/service.go @@ -3,7 +3,9 @@ package inbox import ( "context" "encoding/json" + "fmt" "log/slog" + "math" "time" "github.com/google/uuid" @@ -131,10 +133,13 @@ func (s *Service) List(ctx context.Context, botID string, filter ListFilter) ([] if limit > 500 { limit = 500 } + if filter.Offset < 0 || filter.Offset > math.MaxInt32 { + return nil, fmt.Errorf("offset out of range: %d", filter.Offset) + } rows, err := s.queries.ListInboxItems(ctx, sqlc.ListInboxItemsParams{ BotID: botUUID, IsRead: boolOrNull(filter.IsRead), - Source: textOrNull(filter.Source), + Source: textOrNull(filter.Source), MaxCount: int32(limit), ItemOffset: int32(filter.Offset), }) diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 81afe073..a88d39f3 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -10,8 +10,8 @@ import ( type ctxKey struct{} var ( - L *slog.Logger = slog.Default() - logKey = ctxKey{} + L = slog.Default() + logKey = ctxKey{} ) // Init initializes the global logger with the given level and format (e.g. "debug", "json"). @@ -60,7 +60,18 @@ func parseLevel(level string) slog.Level { } // Debug, Info, Warn, Error log with the global logger (slog.Attr or key-value pairs). -func Debug(msg string, args ...any) { L.Debug(msg, args...) } -func Info(msg string, args ...any) { L.Info(msg, args...) } -func Warn(msg string, args ...any) { L.Warn(msg, args...) } -func Error(msg string, args ...any) { L.Error(msg, args...) } +func Debug(msg string, args ...any) { + L.Log(context.Background(), slog.LevelDebug, "global log", append([]any{slog.String("message", msg)}, args...)...) +} + +func Info(msg string, args ...any) { + L.Log(context.Background(), slog.LevelInfo, "global log", append([]any{slog.String("message", msg)}, args...)...) +} + +func Warn(msg string, args ...any) { + L.Log(context.Background(), slog.LevelWarn, "global log", append([]any{slog.String("message", msg)}, args...)...) +} + +func Error(msg string, args ...any) { + L.Log(context.Background(), slog.LevelError, "global log", append([]any{slog.String("message", msg)}, args...)...) +} diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index 72d3eea8..65fdcb88 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -21,7 +21,7 @@ func TestContextLogger(t *testing.T) { expectedKey := "request_id" expectedValue := "12345" - customLogger := L.With(expectedKey, expectedValue) + customLogger := L.With(slog.String(expectedKey, expectedValue)) ctx := WithContext(context.Background(), customLogger) extracted := FromContext(ctx) diff --git a/internal/mcp/connections.go b/internal/mcp/connections.go index 0a15ea28..462f02c6 100644 --- a/internal/mcp/connections.go +++ b/internal/mcp/connections.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -14,19 +15,19 @@ import ( // Connection represents a stored MCP connection for a bot. type Connection struct { - ID string `json:"id"` - BotID string `json:"bot_id"` - Name string `json:"name"` - Type string `json:"type"` - Config map[string]any `json:"config"` - Active bool `json:"is_active"` - Status string `json:"status"` + ID string `json:"id"` + BotID string `json:"bot_id"` + Name string `json:"name"` + Type string `json:"type"` + Config map[string]any `json:"config"` + Active bool `json:"is_active"` + Status string `json:"status"` ToolsCache []ToolDescriptor `json:"tools_cache"` - LastProbedAt *time.Time `json:"last_probed_at,omitempty"` - StatusMessage string `json:"status_message"` - AuthType string `json:"auth_type"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + LastProbedAt *time.Time `json:"last_probed_at,omitempty"` + StatusMessage string `json:"status_message"` + AuthType string `json:"auth_type"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // UpsertRequest accepts standard mcpServers item format. @@ -90,7 +91,7 @@ func NewConnectionService(log *slog.Logger, queries *sqlc.Queries) *ConnectionSe // ListByBot returns all MCP connections for a bot. func (s *ConnectionService) ListByBot(ctx context.Context, botID string) ([]Connection, error) { if s.queries == nil { - return nil, fmt.Errorf("mcp queries not configured") + return nil, errors.New("mcp queries not configured") } pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -129,7 +130,7 @@ func (s *ConnectionService) ListActiveByBot(ctx context.Context, botID string) ( // Get returns a specific MCP connection for a bot. func (s *ConnectionService) Get(ctx context.Context, botID, id string) (Connection, error) { if s.queries == nil { - return Connection{}, fmt.Errorf("mcp queries not configured") + return Connection{}, errors.New("mcp queries not configured") } pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -152,7 +153,7 @@ func (s *ConnectionService) Get(ctx context.Context, botID, id string) (Connecti // Create inserts a new MCP connection for a bot. func (s *ConnectionService) Create(ctx context.Context, botID string, req UpsertRequest) (Connection, error) { if s.queries == nil { - return Connection{}, fmt.Errorf("mcp queries not configured") + return Connection{}, errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -160,7 +161,7 @@ func (s *ConnectionService) Create(ctx context.Context, botID string, req Upsert } name := strings.TrimSpace(req.Name) if name == "" { - return Connection{}, fmt.Errorf("name is required") + return Connection{}, errors.New("name is required") } mcpType, config, err := inferTypeAndConfig(req) if err != nil { @@ -195,7 +196,7 @@ func (s *ConnectionService) Create(ctx context.Context, botID string, req Upsert // Update modifies an existing MCP connection. func (s *ConnectionService) Update(ctx context.Context, botID, id string, req UpsertRequest) (Connection, error) { if s.queries == nil { - return Connection{}, fmt.Errorf("mcp queries not configured") + return Connection{}, errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -207,7 +208,7 @@ func (s *ConnectionService) Update(ctx context.Context, botID, id string, req Up } name := strings.TrimSpace(req.Name) if name == "" { - return Connection{}, fmt.Errorf("name is required") + return Connection{}, errors.New("name is required") } mcpType, config, err := inferTypeAndConfig(req) if err != nil { @@ -246,7 +247,7 @@ func (s *ConnectionService) Update(ctx context.Context, botID, id string, req Up // Connections not in the input are left untouched. func (s *ConnectionService) Import(ctx context.Context, botID string, req ImportRequest) ([]Connection, error) { if s.queries == nil { - return nil, fmt.Errorf("mcp queries not configured") + return nil, errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -304,7 +305,7 @@ func (s *ConnectionService) ExportByBot(ctx context.Context, botID string) (Expo // Delete removes an MCP connection. func (s *ConnectionService) Delete(ctx context.Context, botID, id string) error { if s.queries == nil { - return fmt.Errorf("mcp queries not configured") + return errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -323,7 +324,7 @@ func (s *ConnectionService) Delete(ctx context.Context, botID, id string) error // BatchDelete removes multiple MCP connections by IDs. Invalid IDs are skipped; at least one must succeed for no error. func (s *ConnectionService) BatchDelete(ctx context.Context, botID string, ids []string) error { if s.queries == nil { - return fmt.Errorf("mcp queries not configured") + return errors.New("mcp queries not configured") } if len(ids) == 0 { return nil @@ -383,7 +384,7 @@ func decodeToolsCache(raw []byte) ([]ToolDescriptor, error) { // UpdateProbeResult persists the result of a probe operation. func (s *ConnectionService) UpdateProbeResult(ctx context.Context, botID, id, status string, tools []ToolDescriptor, message string) error { if s.queries == nil { - return fmt.Errorf("mcp queries not configured") + return errors.New("mcp queries not configured") } pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -426,10 +427,10 @@ func inferTypeAndConfig(req UpsertRequest) (string, map[string]any, error) { hasURL := strings.TrimSpace(req.URL) != "" if !hasCommand && !hasURL { - return "", nil, fmt.Errorf("command or url is required") + return "", nil, errors.New("command or url is required") } if hasCommand && hasURL { - return "", nil, fmt.Errorf("command and url are mutually exclusive") + return "", nil, errors.New("command and url are mutually exclusive") } config := map[string]any{} diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index b1223c62..c3a4f721 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "strings" "sync" "time" @@ -34,9 +33,9 @@ type Manager struct { logger *slog.Logger containerLockMu sync.Mutex containerLocks map[string]*sync.Mutex - mu sync.RWMutex - containerIPs map[string]string - grpcPool *mcpclient.Pool + mu sync.RWMutex + containerIPs map[string]string + grpcPool *mcpclient.Pool } func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, conn *pgxpool.Pool) *Manager { diff --git a/internal/mcp/mcpclient/client.go b/internal/mcp/mcpclient/client.go index d0a2c113..f8b9728b 100644 --- a/internal/mcp/mcpclient/client.go +++ b/internal/mcp/mcpclient/client.go @@ -7,15 +7,17 @@ package mcpclient import ( "bytes" "context" + "errors" "fmt" "io" "sync" - "github.com/memohai/memoh/internal/config" - pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" + + "github.com/memohai/memoh/internal/config" + pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" ) // Client wraps a gRPC connection to a single MCP container. @@ -36,7 +38,7 @@ func NewClientFromConn(conn *grpc.ClientConn) *Client { } // Dial creates a new Client connected to the given container IP. -func Dial(ctx context.Context, ip string) (*Client, error) { +func Dial(_ context.Context, ip string) (*Client, error) { target := fmt.Sprintf("%s:%d", ip, config.MCPGRPCPort) conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -136,7 +138,7 @@ func (c *Client) ExecWithStdin(ctx context.Context, command, workDir string, tim for { msg, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { @@ -276,7 +278,7 @@ func (r *streamReader) Read(p []byte) (int, error) { return n, nil } -func (r *streamReader) Close() error { +func (*streamReader) Close() error { return nil } @@ -334,7 +336,7 @@ func (p *Pool) Get(ctx context.Context, botID string) (*Client, error) { p.mu.Lock() if existing, ok := p.clients[botID]; ok { p.mu.Unlock() - c.Close() + _ = c.Close() return existing, nil } p.clients[botID] = c @@ -346,7 +348,7 @@ func (p *Pool) Get(ctx context.Context, botID string) (*Client, error) { func (p *Pool) Remove(botID string) { p.mu.Lock() if c, ok := p.clients[botID]; ok { - c.Close() + _ = c.Close() delete(p.clients, botID) } p.mu.Unlock() @@ -356,7 +358,7 @@ func (p *Pool) Remove(botID string) { func (p *Pool) CloseAll() { p.mu.Lock() for id, c := range p.clients { - c.Close() + _ = c.Close() delete(p.clients, id) } p.mu.Unlock() diff --git a/internal/mcp/migrate.go b/internal/mcp/migrate.go index b7850395..a17e503d 100644 --- a/internal/mcp/migrate.go +++ b/internal/mcp/migrate.go @@ -91,11 +91,11 @@ func (m *Manager) migrateBindMountData(ctx context.Context, botID string) { } func copyFileToContainer(ctx context.Context, client *mcpclient.Client, hostPath, containerRelPath string) error { - f, err := os.Open(hostPath) + f, err := os.Open(hostPath) //nolint:gosec // G304: hostPath is an operator-configured migration asset path, not user input if err != nil { return err } - defer f.Close() + defer func() { _ = f.Close() }() containerRelPath = strings.ReplaceAll(containerRelPath, string(filepath.Separator), "/") _, err = client.WriteRaw(ctx, containerRelPath, f) diff --git a/internal/mcp/oauth.go b/internal/mcp/oauth.go index 186bd5b9..614cad56 100644 --- a/internal/mcp/oauth.go +++ b/internal/mcp/oauth.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -15,6 +16,7 @@ import ( "time" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -69,11 +71,11 @@ type AuthorizeResult struct { // Discover performs the MCP OAuth discovery flow: // 1. Send request to MCP server, expect 401 with WWW-Authenticate // 2. Fetch Protected Resource Metadata -// 3. Fetch Authorization Server Metadata +// 3. Fetch Authorization Server Metadata. func (s *OAuthService) Discover(ctx context.Context, serverURL string) (*DiscoveryResult, error) { serverURL = strings.TrimSpace(serverURL) if serverURL == "" { - return nil, fmt.Errorf("server URL is required") + return nil, errors.New("server URL is required") } resourceURI := canonicalResourceURI(serverURL) @@ -116,7 +118,7 @@ func (s *OAuthService) Discover(ctx context.Context, serverURL string) (*Discove if prmErr != nil { return nil, fmt.Errorf("failed to fetch protected resource metadata: %w", prmErr) } - return nil, fmt.Errorf("no authorization servers found in protected resource metadata") + return nil, errors.New("no authorization servers found in protected resource metadata") } // Step 3: Fetch Authorization Server Metadata @@ -183,7 +185,7 @@ func (s *OAuthService) StartAuthorization(ctx context.Context, connectionID, cli } if token.AuthorizationEndpoint == "" { - return nil, fmt.Errorf("authorization endpoint not configured") + return nil, errors.New("authorization endpoint not configured") } // Resolve client_id via priority chain @@ -221,7 +223,7 @@ func (s *OAuthService) StartAuthorization(ctx context.Context, connectionID, cli } } if clientID == "" { - return nil, fmt.Errorf("client_id is required: the authorization server does not support automatic registration, please provide a client_id from a registered OAuth application") + return nil, errors.New("client_id is required: the authorization server does not support automatic registration, please provide a client_id from a registered OAuth application") } // Persist client_secret if provided by the user @@ -277,7 +279,7 @@ func (s *OAuthService) StartAuthorization(ctx context.Context, connectionID, cli // HandleCallback exchanges the authorization code for tokens. func (s *OAuthService) HandleCallback(ctx context.Context, state, code string) (string, error) { if state == "" || code == "" { - return "", fmt.Errorf("state and code are required") + return "", errors.New("state and code are required") } token, err := s.queries.GetMCPOAuthTokenByState(ctx, state) @@ -286,7 +288,7 @@ func (s *OAuthService) HandleCallback(ctx context.Context, state, code string) ( } if token.TokenEndpoint == "" || token.PkceCodeVerifier == "" { - return "", fmt.Errorf("invalid OAuth state: missing token endpoint or code verifier") + return "", errors.New("invalid OAuth state: missing token endpoint or code verifier") } redirectURI := token.RedirectUri @@ -336,12 +338,12 @@ func (s *OAuthService) GetValidToken(ctx context.Context, connectionID string) ( } if token.AccessToken == "" { - return "", fmt.Errorf("no access token available, authorization required") + return "", errors.New("no access token available, authorization required") } if token.ExpiresAt.Valid && time.Now().After(token.ExpiresAt.Time.Add(-30*time.Second)) { if token.RefreshToken == "" { - return "", fmt.Errorf("access token expired and no refresh token available") + return "", errors.New("access token expired and no refresh token available") } refreshed, err := s.refreshToken(ctx, token.TokenEndpoint, token.RefreshToken, token.ClientID, token.ResourceUri) if err != nil { @@ -429,8 +431,8 @@ type authServerMetadata struct { } type tokenResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` //nolint:gosec // intentional: OAuth token response field + RefreshToken string `json:"refresh_token"` //nolint:gosec // intentional: OAuth token response field TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` Scope string `json:"scope"` @@ -445,12 +447,12 @@ func (s *OAuthService) probeForAuth(ctx context.Context, serverURL string) (reso } req.Header.Set("Content-Type", "application/json") - resp, err := s.httpClient.Do(req) + resp, err := s.httpClient.Do(req) //nolint:gosec // G704: URL is from OAuth server discovery metadata or operator config, not user input if err != nil { return "", "", err } - defer resp.Body.Close() - io.Copy(io.Discard, resp.Body) + defer func() { _ = resp.Body.Close() }() + _, _ = io.Copy(io.Discard, resp.Body) if resp.StatusCode != http.StatusUnauthorized { return "", "", fmt.Errorf("expected 401 Unauthorized, got %d (server may not require OAuth)", resp.StatusCode) @@ -466,7 +468,7 @@ func (s *OAuthService) probeForAuth(ctx context.Context, serverURL string) (reso return resourceMetaURL, scope, nil } -func (s *OAuthService) guessResourceMetadataURL(serverURL string) string { +func (*OAuthService) guessResourceMetadataURL(serverURL string) string { parsed, err := url.Parse(serverURL) if err != nil { return "" @@ -483,11 +485,11 @@ func (s *OAuthService) fetchProtectedResourceMetadata(ctx context.Context, metad if err != nil { return nil, err } - resp, err := s.httpClient.Do(req) + resp, err := s.httpClient.Do(req) //nolint:gosec // G704: URL is from OAuth server discovery metadata or operator config, not user input if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) @@ -545,11 +547,11 @@ func (s *OAuthService) tryFetchASMetadata(ctx context.Context, metadataURL strin if err != nil { return nil, err } - resp, err := s.httpClient.Do(req) + resp, err := s.httpClient.Do(req) //nolint:gosec // G704: URL is from OAuth server discovery metadata or operator config, not user input if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("metadata endpoint %s returned %d", metadataURL, resp.StatusCode) @@ -560,7 +562,7 @@ func (s *OAuthService) tryFetchASMetadata(ctx context.Context, metadataURL strin return nil, err } if meta.AuthorizationEndpoint == "" || meta.TokenEndpoint == "" { - return nil, fmt.Errorf("metadata missing required endpoints") + return nil, errors.New("metadata missing required endpoints") } return &meta, nil } @@ -596,11 +598,11 @@ func (s *OAuthService) exchangeCode(ctx context.Context, tokenEndpoint, code, co req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") - resp, err := s.httpClient.Do(req) + resp, err := s.httpClient.Do(req) //nolint:gosec // G704: URL is from OAuth server discovery metadata or operator config, not user input if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) if err != nil { @@ -630,7 +632,7 @@ func parseTokenResponse(body []byte) (*tokenResponse, error) { return nil, fmt.Errorf("%s", tok.Error) } if tok.AccessToken == "" { - return nil, fmt.Errorf("no access_token in response") + return nil, errors.New("no access_token in response") } if tok.TokenType == "" { tok.TokenType = "Bearer" @@ -659,7 +661,7 @@ func parseTokenResponse(body []byte) (*tokenResponse, error) { tok.TokenType = "Bearer" } if tok.AccessToken == "" { - return nil, fmt.Errorf("no access_token in response") + return nil, errors.New("no access_token in response") } return &tok, nil } @@ -688,11 +690,11 @@ func (s *OAuthService) refreshToken(ctx context.Context, tokenEndpoint, refreshT req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") - resp, err := s.httpClient.Do(req) + resp, err := s.httpClient.Do(req) //nolint:gosec // G704: URL is from OAuth server discovery metadata or operator config, not user input if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) if err != nil { @@ -722,7 +724,7 @@ type dcrRequest struct { type dcrResponse struct { ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` //nolint:gosec // intentional: OAuth Dynamic Client Registration response field } func (s *OAuthService) registerClient(ctx context.Context, registrationEndpoint, callbackURL string) (*dcrResponse, error) { @@ -744,11 +746,11 @@ func (s *OAuthService) registerClient(ctx context.Context, registrationEndpoint, } req.Header.Set("Content-Type", "application/json") - resp, err := s.httpClient.Do(req) + resp, err := s.httpClient.Do(req) //nolint:gosec // G704: URL is from OAuth server discovery metadata or operator config, not user input if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) @@ -760,7 +762,7 @@ func (s *OAuthService) registerClient(ctx context.Context, registrationEndpoint, return nil, fmt.Errorf("failed to decode DCR response: %w", err) } if result.ClientID == "" { - return nil, fmt.Errorf("DCR response missing client_id") + return nil, errors.New("DCR response missing client_id") } return &result, nil } diff --git a/internal/mcp/providers/container/provider.go b/internal/mcp/providers/container/provider.go index eafaf0f2..a571a092 100644 --- a/internal/mcp/providers/container/provider.go +++ b/internal/mcp/providers/container/provider.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log/slog" + "math" "strings" mcpgw "github.com/memohai/memoh/internal/mcp" @@ -47,7 +48,7 @@ func NewExecutor(log *slog.Logger, clients mcpclient.Provider, execWorkDir strin } // ListTools returns read, write, list, edit, and exec tool descriptors. -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { wd := p.execWorkDir if wd == "" { wd = defaultExecWorkDir @@ -199,6 +200,9 @@ func (p *Executor) callRead(ctx context.Context, client *mcpclient.Client, args if offset < 1 { return mcpgw.BuildToolErrorResult("line_offset must be >= 1"), nil } + if offset > math.MaxInt32 { + return mcpgw.BuildToolErrorResult("line_offset exceeds maximum"), nil + } lineOffset = int32(offset) } @@ -277,7 +281,7 @@ func (p *Executor) callEdit(ctx context.Context, client *mcpclient.Client, args if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer reader.Close() + defer func() { _ = reader.Close() }() raw, err := io.ReadAll(reader) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil diff --git a/internal/mcp/providers/container/provider_test.go b/internal/mcp/providers/container/provider_test.go index bbfc3bca..f9972f97 100644 --- a/internal/mcp/providers/container/provider_test.go +++ b/internal/mcp/providers/container/provider_test.go @@ -2,16 +2,19 @@ package container import ( "context" + "math" "net" + "strings" "sync" "testing" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + mcpgw "github.com/memohai/memoh/internal/mcp" "github.com/memohai/memoh/internal/mcp/mcpclient" pb "github.com/memohai/memoh/internal/mcp/mcpcontainer" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/test/bufconn" ) const bufSize = 1 << 20 @@ -53,7 +56,7 @@ func (f *fakeContainerService) ReadFile(_ context.Context, req *pb.ReadFileReque } content := string(data) lines := splitLines(content) - total := int32(len(lines)) + total := int32(min(len(lines), math.MaxInt32)) //nolint:gosec // G115: value is clamped to math.MaxInt32 above offset := req.GetLineOffset() if offset < 1 { @@ -73,12 +76,14 @@ func (f *fakeContainerService) ReadFile(_ context.Context, req *pb.ReadFileReque end = len(lines) } result := "" + var resultSb76 strings.Builder for i, l := range lines[start:end] { if i > 0 { - result += "\n" + resultSb76.WriteString("\n") } - result += l + resultSb76.WriteString(l) } + result += resultSb76.String() return &pb.ReadFileResponse{Content: result, TotalLines: total}, nil } @@ -186,7 +191,7 @@ func testSetup(t *testing.T, svc *fakeContainerService) mcpclient.Provider { if err != nil { t.Fatalf("grpc.NewClient: %v", err) } - t.Cleanup(func() { conn.Close() }) + t.Cleanup(func() { _ = conn.Close() }) client := mcpclient.NewClientFromConn(conn) return &staticProvider{client: client} diff --git a/internal/mcp/providers/container/prune.go b/internal/mcp/providers/container/prune.go index c448f6aa..a1cbdf99 100644 --- a/internal/mcp/providers/container/prune.go +++ b/internal/mcp/providers/container/prune.go @@ -1,9 +1,6 @@ package container import ( - "strings" - "unicode/utf8" - textprune "github.com/memohai/memoh/internal/prune" ) @@ -39,87 +36,3 @@ func pruneToolOutputText(text, label string) string { Marker: textprune.DefaultMarker, }) } - -// pruneReadOutput prunes read tool output. -func pruneReadOutput(text string) string { - return textprune.PruneWithEdges(text, "read output", textprune.Config{ - MaxBytes: readMaxBytes, - MaxLines: readMaxLines, - HeadBytes: readHeadBytes, - TailBytes: readTailBytes, - HeadLines: readHeadLines, - TailLines: readTailLines, - Marker: textprune.DefaultMarker, - }) -} - -// truncateLine truncates a line to maxLength runes (not bytes) and adds ellipsis if truncated. -func truncateLine(line string, maxLength int) string { - if maxLength <= 0 { - return line - } - - // Count runes, not bytes. - runeCount := utf8.RuneCountInString(line) - if runeCount <= maxLength { - return line - } - - // Find the byte position where we should cut (at maxLength runes). - bytePos := 0 - runes := 0 - for bytePos < len(line) && runes < maxLength { - _, size := utf8.DecodeRuneInString(line[bytePos:]) - bytePos += size - runes++ - } - - return line[:bytePos] + "..." -} - -// formatTruncatedLines formats a list of line numbers for display, collapsing consecutive numbers. -func formatTruncatedLines(lines []int) string { - if len(lines) == 0 { - return "" - } - if len(lines) == 1 { - return itoa(lines[0]) - } - if len(lines) <= 3 { - parts := make([]string, len(lines)) - for i, n := range lines { - parts[i] = itoa(n) - } - return strings.Join(parts, ", ") - } - // For many truncated lines, show count and examples. - return itoa(lines[0]) + ", " + itoa(lines[1]) + ", " + itoa(lines[2]) + "... (" + itoa(len(lines)) + " total)" -} - -// itoa converts int to string without allocation. -func itoa(n int) string { - if n == 0 { - return "0" - } - var buf [20]byte - i := len(buf) - sign := n < 0 - var u uint64 - if sign { - // Avoid overflow for MinInt. - u = uint64(-(n + 1)) - u++ - } else { - u = uint64(n) - } - for u > 0 { - i-- - buf[i] = byte('0' + u%10) - u /= 10 - } - if sign { - i-- - buf[i] = '-' - } - return string(buf[i:]) -} diff --git a/internal/mcp/providers/container/prune_test.go b/internal/mcp/providers/container/prune_test.go deleted file mode 100644 index cc41dfb6..00000000 --- a/internal/mcp/providers/container/prune_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package container - -import ( - "strconv" - "testing" -) - -func TestItoa_MatchesStrconv(t *testing.T) { - minInt := -int(^uint(0)>>1) - 1 - - tests := []int{ - 0, - 1, - -1, - 42, - -42, - 123456789, - -123456789, - minInt, - } - - for _, n := range tests { - got := itoa(n) - want := strconv.Itoa(n) - if got != want { - t.Fatalf("itoa(%d) = %q, want %q", n, got, want) - } - } -} diff --git a/internal/mcp/providers/email/provider.go b/internal/mcp/providers/email/provider.go index e808e185..a217809d 100644 --- a/internal/mcp/providers/email/provider.go +++ b/internal/mcp/providers/email/provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "math" "strconv" "strings" @@ -32,7 +33,7 @@ func NewExecutor(log *slog.Logger, service *email.Service, manager *email.Manage } } -func (e *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (*Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { return []mcpgw.ToolDescriptor{ { Name: toolListEmailAccounts, @@ -147,7 +148,7 @@ func (e *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex } } -func (e *Executor) callAccounts(_ context.Context, bindings []email.BindingResponse) (map[string]any, error) { +func (*Executor) callAccounts(_ context.Context, bindings []email.BindingResponse) (map[string]any, error) { accounts := make([]map[string]any, 0, len(bindings)) for _, b := range bindings { accounts = append(accounts, map[string]any{ @@ -250,6 +251,9 @@ func (e *Executor) callRead(ctx context.Context, providerID string, args map[str if uidRaw <= 0 { return mcpgw.BuildToolErrorResult("uid is required"), nil } + if uidRaw > math.MaxUint32 { + return mcpgw.BuildToolErrorResult("uid out of range"), nil + } providerName, config, err := e.service.ProviderConfig(ctx, providerID) if err != nil { diff --git a/internal/mcp/providers/inbox/provider.go b/internal/mcp/providers/inbox/provider.go index 3f569e41..79b97552 100644 --- a/internal/mcp/providers/inbox/provider.go +++ b/internal/mcp/providers/inbox/provider.go @@ -7,15 +7,14 @@ import ( "strings" "time" - mcpgw "github.com/memohai/memoh/internal/mcp" - inboxsvc "github.com/memohai/memoh/internal/inbox" + mcpgw "github.com/memohai/memoh/internal/mcp" ) const ( - toolSearchInbox = "search_inbox" - defaultSearchLimit = 20 - maxSearchLimit = 100 + toolSearchInbox = "search_inbox" + defaultSearchLimit = 20 + maxSearchLimit = 100 ) type Executor struct { @@ -33,7 +32,7 @@ func NewExecutor(log *slog.Logger, service *inboxsvc.Service) *Executor { } } -func (e *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (e *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if e.service == nil { return []mcpgw.ToolDescriptor{}, nil } diff --git a/internal/mcp/providers/memory/provider_test.go b/internal/mcp/providers/memory/provider_test.go index f12bc59d..2b637a3b 100644 --- a/internal/mcp/providers/memory/provider_test.go +++ b/internal/mcp/providers/memory/provider_test.go @@ -27,44 +27,56 @@ type fakeProvider struct { callErr error } -func (f *fakeProvider) Type() string { return "fake" } -func (f *fakeProvider) OnBeforeChat(_ context.Context, _ memprovider.BeforeChatRequest) (*memprovider.BeforeChatResult, error) { +func (*fakeProvider) Type() string { return "fake" } +func (*fakeProvider) OnBeforeChat(_ context.Context, _ memprovider.BeforeChatRequest) (*memprovider.BeforeChatResult, error) { return nil, nil } -func (f *fakeProvider) OnAfterChat(_ context.Context, _ memprovider.AfterChatRequest) error { + +func (*fakeProvider) OnAfterChat(_ context.Context, _ memprovider.AfterChatRequest) error { return nil } + func (f *fakeProvider) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { return f.tools, nil } + func (f *fakeProvider) CallTool(_ context.Context, _ mcpgw.ToolSessionContext, _ string, _ map[string]any) (map[string]any, error) { return f.callResp, f.callErr } -func (f *fakeProvider) Add(_ context.Context, _ memprovider.AddRequest) (memprovider.SearchResponse, error) { + +func (*fakeProvider) Add(_ context.Context, _ memprovider.AddRequest) (memprovider.SearchResponse, error) { return memprovider.SearchResponse{}, nil } -func (f *fakeProvider) Search(_ context.Context, _ memprovider.SearchRequest) (memprovider.SearchResponse, error) { + +func (*fakeProvider) Search(_ context.Context, _ memprovider.SearchRequest) (memprovider.SearchResponse, error) { return memprovider.SearchResponse{}, nil } -func (f *fakeProvider) GetAll(_ context.Context, _ memprovider.GetAllRequest) (memprovider.SearchResponse, error) { + +func (*fakeProvider) GetAll(_ context.Context, _ memprovider.GetAllRequest) (memprovider.SearchResponse, error) { return memprovider.SearchResponse{}, nil } -func (f *fakeProvider) Update(_ context.Context, _ memprovider.UpdateRequest) (memprovider.MemoryItem, error) { + +func (*fakeProvider) Update(_ context.Context, _ memprovider.UpdateRequest) (memprovider.MemoryItem, error) { return memprovider.MemoryItem{}, nil } -func (f *fakeProvider) Delete(_ context.Context, _ string) (memprovider.DeleteResponse, error) { + +func (*fakeProvider) Delete(_ context.Context, _ string) (memprovider.DeleteResponse, error) { return memprovider.DeleteResponse{}, nil } -func (f *fakeProvider) DeleteBatch(_ context.Context, _ []string) (memprovider.DeleteResponse, error) { + +func (*fakeProvider) DeleteBatch(_ context.Context, _ []string) (memprovider.DeleteResponse, error) { return memprovider.DeleteResponse{}, nil } -func (f *fakeProvider) DeleteAll(_ context.Context, _ memprovider.DeleteAllRequest) (memprovider.DeleteResponse, error) { + +func (*fakeProvider) DeleteAll(_ context.Context, _ memprovider.DeleteAllRequest) (memprovider.DeleteResponse, error) { return memprovider.DeleteResponse{}, nil } -func (f *fakeProvider) Compact(_ context.Context, _ map[string]any, _ float64, _ int) (memprovider.CompactResult, error) { + +func (*fakeProvider) Compact(_ context.Context, _ map[string]any, _ float64, _ int) (memprovider.CompactResult, error) { return memprovider.CompactResult{}, nil } -func (f *fakeProvider) Usage(_ context.Context, _ map[string]any) (memprovider.UsageResponse, error) { + +func (*fakeProvider) Usage(_ context.Context, _ map[string]any) (memprovider.UsageResponse, error) { return memprovider.UsageResponse{}, nil } diff --git a/internal/mcp/providers/message/provider.go b/internal/mcp/providers/message/provider.go index 87e6fb1d..d45857cc 100644 --- a/internal/mcp/providers/message/provider.go +++ b/internal/mcp/providers/message/provider.go @@ -3,7 +3,7 @@ package message import ( "context" "encoding/json" - "fmt" + "errors" "log/slog" "path/filepath" "strings" @@ -72,7 +72,7 @@ func NewExecutor(log *slog.Logger, sender Sender, reactor Reactor, resolver Chan } } -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { var tools []mcpgw.ToolDescriptor if p.sender != nil && p.resolver != nil { tools = append(tools, mcpgw.ToolDescriptor{ @@ -295,16 +295,16 @@ func (p *Executor) callReact(ctx context.Context, session mcpgw.ToolSessionConte // --- shared helpers --- -func (p *Executor) resolveBotID(arguments map[string]any, session mcpgw.ToolSessionContext) (string, error) { +func (*Executor) resolveBotID(arguments map[string]any, session mcpgw.ToolSessionContext) (string, error) { botID := mcpgw.FirstStringArg(arguments, "bot_id") if botID == "" { botID = strings.TrimSpace(session.BotID) } if botID == "" { - return "", fmt.Errorf("bot_id is required") + return "", errors.New("bot_id is required") } if strings.TrimSpace(session.BotID) != "" && botID != strings.TrimSpace(session.BotID) { - return "", fmt.Errorf("bot_id mismatch") + return "", errors.New("bot_id mismatch") } return botID, nil } @@ -315,7 +315,7 @@ func (p *Executor) resolvePlatform(arguments map[string]any, session mcpgw.ToolS platform = strings.TrimSpace(session.CurrentPlatform) } if platform == "" { - return "", fmt.Errorf("platform is required") + return "", errors.New("platform is required") } return p.resolver.ParseChannelType(platform) } @@ -496,14 +496,14 @@ func parseOutboundMessage(arguments map[string]any, fallbackText string) (channe return channel.Message{}, err } default: - return channel.Message{}, fmt.Errorf("message must be object or string") + return channel.Message{}, errors.New("message must be object or string") } } if msg.IsEmpty() && strings.TrimSpace(fallbackText) != "" { msg.Text = strings.TrimSpace(fallbackText) } if msg.IsEmpty() { - return channel.Message{}, fmt.Errorf("message is required") + return channel.Message{}, errors.New("message is required") } return msg, nil } diff --git a/internal/mcp/providers/message/provider_test.go b/internal/mcp/providers/message/provider_test.go index 62998dc1..9660ac12 100644 --- a/internal/mcp/providers/message/provider_test.go +++ b/internal/mcp/providers/message/provider_test.go @@ -14,7 +14,7 @@ type fakeSender struct { lastReq channel.SendRequest } -func (f *fakeSender) Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error { +func (f *fakeSender) Send(_ context.Context, _ string, _ channel.ChannelType, req channel.SendRequest) error { f.lastReq = req return f.err } @@ -24,7 +24,7 @@ type fakeReactor struct { lastReq channel.ReactRequest } -func (f *fakeReactor) React(ctx context.Context, botID string, channelType channel.ChannelType, req channel.ReactRequest) error { +func (f *fakeReactor) React(_ context.Context, _ string, _ channel.ChannelType, req channel.ReactRequest) error { f.lastReq = req return f.err } @@ -34,7 +34,7 @@ type fakeResolver struct { err error } -func (f *fakeResolver) ParseChannelType(raw string) (channel.ChannelType, error) { +func (f *fakeResolver) ParseChannelType(_ string) (channel.ChannelType, error) { if f.err != nil { return "", f.err } @@ -95,7 +95,7 @@ func TestExecutor_CallTool_NotFound(t *testing.T) { resolver := &fakeResolver{ct: channel.ChannelType("feishu")} exec := NewExecutor(nil, sender, nil, resolver, nil) _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) - if err != mcpgw.ErrToolNotFound { + if !errors.Is(err, mcpgw.ErrToolNotFound) { t.Errorf("expected ErrToolNotFound, got %v", err) } } diff --git a/internal/mcp/providers/schedule/provider.go b/internal/mcp/providers/schedule/provider.go index 156d908a..c8ad8a3a 100644 --- a/internal/mcp/providers/schedule/provider.go +++ b/internal/mcp/providers/schedule/provider.go @@ -40,7 +40,7 @@ func NewExecutor(log *slog.Logger, service Scheduler) *Executor { } } -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if p.service == nil { return []mcpgw.ToolDescriptor{}, nil } diff --git a/internal/mcp/providers/schedule/provider_test.go b/internal/mcp/providers/schedule/provider_test.go index 43d7d544..c456726c 100644 --- a/internal/mcp/providers/schedule/provider_test.go +++ b/internal/mcp/providers/schedule/provider_test.go @@ -21,32 +21,32 @@ type fakeScheduler struct { deleteErr error } -func (f *fakeScheduler) List(ctx context.Context, botID string) ([]sched.Schedule, error) { +func (f *fakeScheduler) List(_ context.Context, _ string) ([]sched.Schedule, error) { return f.list, nil } -func (f *fakeScheduler) Get(ctx context.Context, id string) (sched.Schedule, error) { +func (f *fakeScheduler) Get(_ context.Context, _ string) (sched.Schedule, error) { if f.getErr != nil { return sched.Schedule{}, f.getErr } return f.get, nil } -func (f *fakeScheduler) Create(ctx context.Context, botID string, req sched.CreateRequest) (sched.Schedule, error) { +func (f *fakeScheduler) Create(_ context.Context, _ string, _ sched.CreateRequest) (sched.Schedule, error) { if f.createErr != nil { return sched.Schedule{}, f.createErr } return f.create, nil } -func (f *fakeScheduler) Update(ctx context.Context, id string, req sched.UpdateRequest) (sched.Schedule, error) { +func (f *fakeScheduler) Update(_ context.Context, _ string, _ sched.UpdateRequest) (sched.Schedule, error) { if f.updateErr != nil { return sched.Schedule{}, f.updateErr } return f.update, nil } -func (f *fakeScheduler) Delete(ctx context.Context, id string) error { +func (f *fakeScheduler) Delete(_ context.Context, _ string) error { return f.deleteErr } @@ -83,7 +83,7 @@ func TestExecutor_CallTool_NotFound(t *testing.T) { svc := &fakeScheduler{} exec := NewExecutor(nil, svc) _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) - if err != mcpgw.ErrToolNotFound { + if !errors.Is(err, mcpgw.ErrToolNotFound) { t.Errorf("expected ErrToolNotFound, got %v", err) } } diff --git a/internal/mcp/providers/web/provider.go b/internal/mcp/providers/web/provider.go index 4a1cd798..8bf400e9 100644 --- a/internal/mcp/providers/web/provider.go +++ b/internal/mcp/providers/web/provider.go @@ -17,6 +17,7 @@ import ( "net/url" "regexp" "sort" + "strconv" "strings" "time" @@ -46,7 +47,7 @@ func NewExecutor(log *slog.Logger, settingsSvc *settings.Service, searchSvc *sea } } -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if p.settings == nil || p.searchProviders == nil { return []mcpgw.ToolDescriptor{}, nil } @@ -140,7 +141,7 @@ func (p *Executor) callWebSearch(ctx context.Context, providerName string, confi } } -func (p *Executor) callBraveSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callBraveSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := strings.TrimRight(firstNonEmpty(stringValue(cfg["base_url"]), "https://api.search.brave.com/res/v1/web/search"), "/") reqURL, err := url.Parse(endpoint) @@ -149,7 +150,7 @@ func (p *Executor) callBraveSearch(ctx context.Context, configJSON []byte, query } params := reqURL.Query() params.Set("q", query) - params.Set("count", fmt.Sprintf("%d", count)) + params.Set("count", strconv.Itoa(count)) reqURL.RawQuery = params.Encode() timeout := parseTimeout(configJSON, 15*time.Second) @@ -163,11 +164,11 @@ func (p *Executor) callBraveSearch(ctx context.Context, configJSON []byte, query if strings.TrimSpace(apiKey) != "" { req.Header.Set("X-Subscription-Token", strings.TrimSpace(apiKey)) } - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -201,7 +202,7 @@ func (p *Executor) callBraveSearch(ctx context.Context, configJSON []byte, query }), nil } -func (p *Executor) callBingSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callBingSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := strings.TrimRight(firstNonEmpty(stringValue(cfg["base_url"]), "https://api.bing.microsoft.com/v7.0/search"), "/") reqURL, err := url.Parse(endpoint) @@ -210,7 +211,7 @@ func (p *Executor) callBingSearch(ctx context.Context, configJSON []byte, query } params := reqURL.Query() params.Set("q", query) - params.Set("count", fmt.Sprintf("%d", count)) + params.Set("count", strconv.Itoa(count)) reqURL.RawQuery = params.Encode() timeout := parseTimeout(configJSON, 15*time.Second) @@ -224,11 +225,11 @@ func (p *Executor) callBingSearch(ctx context.Context, configJSON []byte, query if strings.TrimSpace(apiKey) != "" { req.Header.Set("Ocp-Apim-Subscription-Key", strings.TrimSpace(apiKey)) } - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -262,7 +263,7 @@ func (p *Executor) callBingSearch(ctx context.Context, configJSON []byte, query }), nil } -func (p *Executor) callGoogleSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callGoogleSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := strings.TrimRight(firstNonEmpty(stringValue(cfg["base_url"]), "https://customsearch.googleapis.com/customsearch/v1"), "/") reqURL, err := url.Parse(endpoint) @@ -279,7 +280,7 @@ func (p *Executor) callGoogleSearch(ctx context.Context, configJSON []byte, quer params := reqURL.Query() params.Set("q", query) params.Set("cx", cx) - params.Set("num", fmt.Sprintf("%d", count)) + params.Set("num", strconv.Itoa(count)) apiKey := stringValue(cfg["api_key"]) if strings.TrimSpace(apiKey) != "" { params.Set("key", strings.TrimSpace(apiKey)) @@ -293,11 +294,11 @@ func (p *Executor) callGoogleSearch(ctx context.Context, configJSON []byte, quer return mcpgw.BuildToolErrorResult(err.Error()), nil } req.Header.Set("Accept", "application/json") - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -329,7 +330,7 @@ func (p *Executor) callGoogleSearch(ctx context.Context, configJSON []byte, quer }), nil } -func (p *Executor) callTavilySearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callTavilySearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://api.tavily.com/search") apiKey := stringValue(cfg["api_key"]) @@ -349,11 +350,11 @@ func (p *Executor) callTavilySearch(ctx context.Context, configJSON []byte, quer req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+apiKey) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -385,7 +386,7 @@ func (p *Executor) callTavilySearch(ctx context.Context, configJSON []byte, quer }), nil } -func (p *Executor) callSogouSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callSogouSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) host := firstNonEmpty(stringValue(cfg["base_url"]), "wsa.tencentcloudapi.com") secretID := stringValue(cfg["secret_id"]) @@ -403,7 +404,7 @@ func (p *Executor) callSogouSearch(ctx context.Context, configJSON []byte, query }) now := time.Now().UTC() - timestamp := fmt.Sprintf("%d", now.Unix()) + timestamp := strconv.FormatInt(now.Unix(), 10) date := now.Format("2006-01-02") hashedPayload := sha256Hex(payload) @@ -437,11 +438,11 @@ func (p *Executor) callSogouSearch(ctx context.Context, configJSON []byte, query req.Header.Set("X-TC-Action", action) req.Header.Set("X-TC-Version", version) req.Header.Set("X-TC-Timestamp", timestamp) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -517,7 +518,7 @@ func hmacSHA256(key, data []byte) []byte { return h.Sum(nil) } -func (p *Executor) callSerperSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callSerperSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://google.serper.dev/search") apiKey := stringValue(cfg["api_key"]) @@ -536,11 +537,11 @@ func (p *Executor) callSerperSearch(ctx context.Context, configJSON []byte, quer req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("X-API-KEY", apiKey) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -579,7 +580,7 @@ func (p *Executor) callSerperSearch(ctx context.Context, configJSON []byte, quer }), nil } -func (p *Executor) callSearXNGSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callSearXNGSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) baseURL := stringValue(cfg["base_url"]) if baseURL == "" { @@ -611,11 +612,11 @@ func (p *Executor) callSearXNGSearch(ctx context.Context, configJSON []byte, que return mcpgw.BuildToolErrorResult(err.Error()), nil } req.Header.Set("Accept", "application/json") - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -654,7 +655,7 @@ func (p *Executor) callSearXNGSearch(ctx context.Context, configJSON []byte, que }), nil } -func (p *Executor) callJinaSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callJinaSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://s.jina.ai/") apiKey := stringValue(cfg["api_key"]) @@ -678,11 +679,11 @@ func (p *Executor) callJinaSearch(ctx context.Context, configJSON []byte, query req.Header.Set("Accept", "application/json") req.Header.Set("X-Retain-Images", "none") req.Header.Set("Authorization", apiKey) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -714,7 +715,7 @@ func (p *Executor) callJinaSearch(ctx context.Context, configJSON []byte, query }), nil } -func (p *Executor) callExaSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callExaSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://api.exa.ai/search") apiKey := stringValue(cfg["api_key"]) @@ -739,11 +740,11 @@ func (p *Executor) callExaSearch(ctx context.Context, configJSON []byte, query s req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+apiKey) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -775,7 +776,7 @@ func (p *Executor) callExaSearch(ctx context.Context, configJSON []byte, query s }), nil } -func (p *Executor) callBochaSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callBochaSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://api.bochaai.com/v1/web-search") apiKey := stringValue(cfg["api_key"]) @@ -797,11 +798,11 @@ func (p *Executor) callBochaSearch(ctx context.Context, configJSON []byte, query req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+apiKey) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -837,7 +838,7 @@ func (p *Executor) callBochaSearch(ctx context.Context, configJSON []byte, query }), nil } -func (p *Executor) callDuckDuckGoSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callDuckDuckGoSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://html.duckduckgo.com/html/") @@ -853,11 +854,11 @@ func (p *Executor) callDuckDuckGoSearch(ctx context.Context, configJSON []byte, } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36") - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -925,7 +926,7 @@ func extractDDGURL(rawURL string) string { return rawURL } -func (p *Executor) callYandexSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { +func (*Executor) callYandexSearch(ctx context.Context, configJSON []byte, query string, count int) (map[string]any, error) { cfg := parseConfig(configJSON) endpoint := firstNonEmpty(stringValue(cfg["base_url"]), "https://searchapi.api.cloud.yandex.net/v2/web/search") apiKey := stringValue(cfg["api_key"]) @@ -952,11 +953,11 @@ func (p *Executor) callYandexSearch(ctx context.Context, configJSON []byte, quer } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Api-Key "+apiKey) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // G704: web browsing tool intentionally fetches user-specified URLs; SSRF is by design if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil @@ -986,7 +987,7 @@ func (p *Executor) callYandexSearch(ctx context.Context, configJSON []byte, quer type xmlInnerText string -func (t *xmlInnerText) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { +func (t *xmlInnerText) UnmarshalXML(d *xml.Decoder, _ xml.StartElement) error { var buf strings.Builder for { tok, err := d.Token() diff --git a/internal/mcp/sources/federation/source.go b/internal/mcp/sources/federation/source.go index 6fa7729f..84afaa3e 100644 --- a/internal/mcp/sources/federation/source.go +++ b/internal/mcp/sources/federation/source.go @@ -56,7 +56,7 @@ func NewSource(log *slog.Logger, gateway Gateway, connections ConnectionLister) log = slog.Default() } return &Source{ - logger: log.With(slog.String("source", "federated_mcp_tool")), + logger: log.With(slog.String("tool_source", "federated_mcp_tool")), gateway: gateway, connections: connections, cache: map[string]cacheEntry{}, diff --git a/internal/mcp/sources/federation/source_test.go b/internal/mcp/sources/federation/source_test.go index 9ba236ed..97e5bc90 100644 --- a/internal/mcp/sources/federation/source_test.go +++ b/internal/mcp/sources/federation/source_test.go @@ -13,7 +13,7 @@ type testConnectionLister struct { err error } -func (l *testConnectionLister) ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) { +func (l *testConnectionLister) ListActiveByBot(_ context.Context, _ string) ([]mcpgw.Connection, error) { if l.err != nil { return nil, l.err } @@ -28,29 +28,29 @@ type testGateway struct { lastCallType string } -func (g *testGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { +func (g *testGateway) ListHTTPConnectionTools(_ context.Context, _ mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { return g.listHTTP, nil } -func (g *testGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { +func (g *testGateway) CallHTTPConnectionTool(_ context.Context, _ mcpgw.Connection, _ string, _ map[string]any) (map[string]any, error) { g.lastCallType = "http" return map[string]any{"result": map[string]any{"ok": true, "route": "http"}}, nil } -func (g *testGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { +func (g *testGateway) ListSSEConnectionTools(_ context.Context, _ mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { return g.listSSE, nil } -func (g *testGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { +func (g *testGateway) CallSSEConnectionTool(_ context.Context, _ mcpgw.Connection, _ string, _ map[string]any) (map[string]any, error) { g.lastCallType = "sse" return map[string]any{"result": map[string]any{"ok": true, "route": "sse"}}, nil } -func (g *testGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { +func (g *testGateway) ListStdioConnectionTools(_ context.Context, _ string, _ mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { return g.listStdio, nil } -func (g *testGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { +func (g *testGateway) CallStdioConnectionTool(_ context.Context, _ string, _ mcpgw.Connection, _ string, _ map[string]any) (map[string]any, error) { g.lastCallType = "stdio" return map[string]any{"result": map[string]any{"ok": true, "route": "stdio"}}, nil } diff --git a/internal/mcp/tool_gateway_service.go b/internal/mcp/tool_gateway_service.go index ad202b59..2a55216d 100644 --- a/internal/mcp/tool_gateway_service.go +++ b/internal/mcp/tool_gateway_service.go @@ -3,7 +3,6 @@ package mcp import ( "context" "errors" - "fmt" "log/slog" "strings" "sync" @@ -55,7 +54,7 @@ func NewToolGatewayService(log *slog.Logger, executors []ToolExecutor, sources [ } } -func (s *ToolGatewayService) InitializeResult() map[string]any { +func (*ToolGatewayService) InitializeResult() map[string]any { return map[string]any{ "protocolVersion": "2025-06-18", "capabilities": map[string]any{ @@ -81,7 +80,7 @@ func (s *ToolGatewayService) ListTools(ctx context.Context, session ToolSessionC func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionContext, payload ToolCallPayload) (map[string]any, error) { toolName := strings.TrimSpace(payload.Name) if toolName == "" { - return nil, fmt.Errorf("tool name is required") + return nil, errors.New("tool name is required") } registry, err := s.getRegistry(ctx, session, false) @@ -121,7 +120,7 @@ func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionCo func (s *ToolGatewayService) getRegistry(ctx context.Context, session ToolSessionContext, force bool) (*ToolRegistry, error) { botID := strings.TrimSpace(session.BotID) if botID == "" { - return nil, fmt.Errorf("bot id is required") + return nil, errors.New("bot id is required") } if !force { s.mu.Lock() diff --git a/internal/mcp/tool_gateway_service_test.go b/internal/mcp/tool_gateway_service_test.go index 3509f7ef..46cfb807 100644 --- a/internal/mcp/tool_gateway_service_test.go +++ b/internal/mcp/tool_gateway_service_test.go @@ -13,11 +13,11 @@ type gatewayTestProvider struct { callErr map[string]error } -func (p *gatewayTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { +func (p *gatewayTestProvider) ListTools(_ context.Context, _ ToolSessionContext) ([]ToolDescriptor, error) { return p.tools, nil } -func (p *gatewayTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { +func (p *gatewayTestProvider) CallTool(_ context.Context, _ ToolSessionContext, toolName string, _ map[string]any) (map[string]any, error) { if err, ok := p.callErr[toolName]; ok { return nil, err } diff --git a/internal/mcp/tool_registry.go b/internal/mcp/tool_registry.go index edd7552c..1f6fdd1f 100644 --- a/internal/mcp/tool_registry.go +++ b/internal/mcp/tool_registry.go @@ -1,6 +1,7 @@ package mcp import ( + "errors" "fmt" "sort" "strings" @@ -24,11 +25,11 @@ func NewToolRegistry() *ToolRegistry { func (r *ToolRegistry) Register(executor ToolExecutor, tool ToolDescriptor) error { if executor == nil { - return fmt.Errorf("tool executor is required") + return errors.New("tool executor is required") } name := strings.TrimSpace(tool.Name) if name == "" { - return fmt.Errorf("tool name is required") + return errors.New("tool name is required") } if tool.InputSchema == nil { tool.InputSchema = map[string]any{ diff --git a/internal/mcp/tool_registry_test.go b/internal/mcp/tool_registry_test.go index f5001d9d..b20e9c19 100644 --- a/internal/mcp/tool_registry_test.go +++ b/internal/mcp/tool_registry_test.go @@ -7,11 +7,11 @@ import ( type registryTestProvider struct{} -func (p *registryTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { +func (*registryTestProvider) ListTools(_ context.Context, _ ToolSessionContext) ([]ToolDescriptor, error) { return nil, nil } -func (p *registryTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { +func (*registryTestProvider) CallTool(_ context.Context, _ ToolSessionContext, _ string, _ map[string]any) (map[string]any, error) { return nil, nil } diff --git a/internal/mcp/tool_types.go b/internal/mcp/tool_types.go index 9a556ec5..a89b0d10 100644 --- a/internal/mcp/tool_types.go +++ b/internal/mcp/tool_types.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "math" "strings" @@ -13,7 +14,7 @@ type ToolSessionContext struct { BotID string ChatID string ChannelIdentityID string - SessionToken string + SessionToken string `json:"-"` CurrentPlatform string ReplyTarget string } @@ -45,7 +46,7 @@ type ToolCallPayload struct { } // ErrToolNotFound indicates the provider does not own the requested tool. -var ErrToolNotFound = fmt.Errorf("tool not found") +var ErrToolNotFound = errors.New("tool not found") // BuildToolSuccessResult builds a standard MCP tool success result object. func BuildToolSuccessResult(structured any) map[string]any { @@ -148,39 +149,68 @@ func IntArg(arguments map[string]any, key string) (int, bool, error) { case int32: return int(value), true, nil case int64: - return int(value), true, nil + i, err := int64ToInt(value, key) + return i, true, err case uint: - return int(value), true, nil + i, err := uint64ToInt(uint64(value), key) + return i, true, err case uint8: return int(value), true, nil case uint16: return int(value), true, nil case uint32: - return int(value), true, nil + i, err := uint64ToInt(uint64(value), key) + return i, true, err case uint64: - return int(value), true, nil + i, err := uint64ToInt(value, key) + return i, true, err case float32: f := float64(value) if math.IsNaN(f) || math.IsInf(f, 0) { return 0, true, fmt.Errorf("%s must be a valid number", key) } - return int(f), true, nil + i, err := float64ToInt(f, key) + return i, true, err case float64: if math.IsNaN(value) || math.IsInf(value, 0) { return 0, true, fmt.Errorf("%s must be a valid number", key) } - return int(value), true, nil + i, err := float64ToInt(value, key) + return i, true, err case json.Number: i, err := value.Int64() if err != nil { return 0, true, fmt.Errorf("%s must be an integer", key) } - return int(i), true, nil + n, convErr := int64ToInt(i, key) + return n, true, convErr default: return 0, true, fmt.Errorf("%s must be a number", key) } } +func int64ToInt(value int64, key string) (int, error) { + if value < int64(math.MinInt) || value > int64(math.MaxInt) { + return 0, fmt.Errorf("%s out of range", key) + } + return int(value), nil +} + +func uint64ToInt(value uint64, key string) (int, error) { + const maxIntAsUint = uint64(math.MaxInt) + if value > maxIntAsUint { + return 0, fmt.Errorf("%s out of range", key) + } + return int(value), nil +} + +func float64ToInt(value float64, key string) (int, error) { + if value < float64(math.MinInt) || value > float64(math.MaxInt) { + return 0, fmt.Errorf("%s out of range", key) + } + return int(value), nil +} + func BoolArg(arguments map[string]any, key string) (bool, bool, error) { if arguments == nil { return false, false, nil diff --git a/internal/mcp/versioning.go b/internal/mcp/versioning.go index f5e83b5c..a16a3ba8 100644 --- a/internal/mcp/versioning.go +++ b/internal/mcp/versioning.go @@ -3,8 +3,10 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "log/slog" + "math" "strings" "time" @@ -13,7 +15,6 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/memohai/memoh/internal/config" - ctr "github.com/memohai/memoh/internal/containerd" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" @@ -55,7 +56,7 @@ type BotSnapshotData struct { func (m *Manager) CreateSnapshot(ctx context.Context, botID, snapshotName, source string) (*SnapshotCreateInfo, error) { if m.db == nil || m.queries == nil { - return nil, fmt.Errorf("db is not configured") + return nil, errors.New("db is not configured") } if err := validateBotID(botID); err != nil { return nil, err @@ -128,7 +129,7 @@ func (m *Manager) CreateSnapshot(ctx context.Context, botID, snapshotName, sourc func (m *Manager) CreateVersion(ctx context.Context, botID string) (*VersionInfo, error) { if m.db == nil || m.queries == nil { - return nil, fmt.Errorf("db is not configured") + return nil, errors.New("db is not configured") } if err := validateBotID(botID); err != nil { return nil, err @@ -254,7 +255,7 @@ func (m *Manager) ListBotSnapshotData(ctx context.Context, botID string) (*BotSn func (m *Manager) ListVersions(ctx context.Context, botID string) ([]VersionInfo, error) { if m.db == nil || m.queries == nil { - return nil, fmt.Errorf("db is not configured") + return nil, errors.New("db is not configured") } if err := validateBotID(botID); err != nil { return nil, err @@ -284,11 +285,14 @@ func (m *Manager) ListVersions(ctx context.Context, botID string) ([]VersionInfo func (m *Manager) RollbackVersion(ctx context.Context, botID string, version int) error { if m.db == nil || m.queries == nil { - return fmt.Errorf("db is not configured") + return errors.New("db is not configured") } if err := validateBotID(botID); err != nil { return err } + if version < 1 || version > math.MaxInt32 { + return errors.New("version out of range") + } containerID := m.containerID(botID) unlock := m.lockContainer(containerID) @@ -327,11 +331,14 @@ func (m *Manager) RollbackVersion(ctx context.Context, botID string, version int func (m *Manager) VersionSnapshotName(ctx context.Context, botID string, version int) (string, error) { if m.db == nil || m.queries == nil { - return "", fmt.Errorf("db is not configured") + return "", errors.New("db is not configured") } if err := validateBotID(botID); err != nil { return "", err } + if version < 1 || version > math.MaxInt32 { + return "", errors.New("version out of range") + } containerID := m.containerID(botID) return m.queries.GetVersionSnapshotRuntimeName(ctx, dbsqlc.GetVersionSnapshotRuntimeNameParams{ @@ -419,7 +426,7 @@ func (m *Manager) safeStopTask(ctx context.Context, containerID string) error { return err } -func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, runtime, imageRef string) (pgtype.UUID, error) { +func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, _ string, imageRef string) (pgtype.UUID, error) { botUUID, err := db.ParseUUID(botID) if err != nil { return pgtype.UUID{}, err @@ -459,7 +466,7 @@ func (m *Manager) recordSnapshotVersion(ctx context.Context, containerID, runtim if err != nil { return "", 0, time.Time{}, err } - defer tx.Rollback(ctx) + defer func() { _ = tx.Rollback(ctx) }() qtx := m.queries.WithTx(tx) diff --git a/internal/media/limits.go b/internal/media/limits.go index 35d128b3..415f0d34 100644 --- a/internal/media/limits.go +++ b/internal/media/limits.go @@ -1,6 +1,7 @@ package media import ( + "errors" "fmt" "io" ) @@ -13,10 +14,10 @@ const ( // ReadAllWithLimit reads from reader and rejects payloads larger than maxBytes. func ReadAllWithLimit(reader io.Reader, maxBytes int64) ([]byte, error) { if reader == nil { - return nil, fmt.Errorf("reader is required") + return nil, errors.New("reader is required") } if maxBytes <= 0 { - return nil, fmt.Errorf("max bytes must be greater than 0") + return nil, errors.New("max bytes must be greater than 0") } limited := &io.LimitedReader{ R: reader, diff --git a/internal/media/service.go b/internal/media/service.go index 34f89511..ed890cc7 100644 --- a/internal/media/service.go +++ b/internal/media/service.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "log/slog" @@ -39,22 +40,23 @@ func (s *Service) Ingest(ctx context.Context, input IngestInput) (Asset, error) return Asset{}, ErrProviderUnavailable } if strings.TrimSpace(input.BotID) == "" { - return Asset{}, fmt.Errorf("bot id is required") + return Asset{}, errors.New("bot id is required") } if input.Reader == nil { - return Asset{}, fmt.Errorf("reader is required") + return Asset{}, errors.New("reader is required") } maxBytes := input.MaxBytes if maxBytes <= 0 { maxBytes = MaxAssetBytes } - contentHash, sizeBytes, tempPath, err := spoolAndHashWithLimit(input.Reader, maxBytes) + contentHash, sizeBytes, tempFile, err := spoolAndHashWithLimit(input.Reader, maxBytes) if err != nil { return Asset{}, fmt.Errorf("read input: %w", err) } defer func() { - _ = os.Remove(tempPath) + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) //nolint:gosec // G703: path is from os.CreateTemp, not from user input }() mime := coalesce(input.Mime, "application/octet-stream") @@ -76,13 +78,6 @@ func (s *Service) Ingest(ctx context.Context, input IngestInput) (Asset, error) }, nil } - tempFile, err := os.Open(tempPath) - if err != nil { - return Asset{}, fmt.Errorf("open temp file: %w", err) - } - defer func() { - _ = tempFile.Close() - }() if err := s.provider.Put(ctx, routingKey, tempFile); err != nil { return Asset{}, fmt.Errorf("store media: %w", err) } @@ -153,13 +148,13 @@ func (s *Service) IngestContainerFile(ctx context.Context, botID, containerPath } opener, ok := s.provider.(storage.ContainerFileOpener) if !ok { - return Asset{}, fmt.Errorf("provider does not support container file reading") + return Asset{}, errors.New("provider does not support container file reading") } f, err := opener.OpenContainerFile(ctx, botID, containerPath) if err != nil { return Asset{}, fmt.Errorf("open container file: %w", err) } - defer f.Close() + defer func() { _ = f.Close() }() ext := path.Ext(containerPath) mime := mimeFromExtension(ext) return s.Ingest(ctx, IngestInput{BotID: botID, Mime: mime, Reader: f, OriginalExt: ext}) @@ -294,38 +289,42 @@ func coalesce(values ...string) string { return "" } -func spoolAndHashWithLimit(reader io.Reader, maxBytes int64) (string, int64, string, error) { +// spoolAndHashWithLimit streams reader into a temp file while computing its SHA-256. +// Returns the open file sought to the beginning; caller must close and remove it. +func spoolAndHashWithLimit(reader io.Reader, maxBytes int64) (contentHash string, size int64, f *os.File, err error) { if reader == nil { - return "", 0, "", fmt.Errorf("reader is required") + return "", 0, nil, errors.New("reader is required") } if maxBytes <= 0 { - return "", 0, "", fmt.Errorf("max bytes must be greater than 0") + return "", 0, nil, errors.New("max bytes must be greater than 0") } - tempFile, err := os.CreateTemp("", "memoh-media-*") - if err != nil { - return "", 0, "", fmt.Errorf("create temp file: %w", err) + tmp, createErr := os.CreateTemp("", "memoh-media-*") + if createErr != nil { + return "", 0, nil, fmt.Errorf("create temp file: %w", createErr) + } + cleanup := func() { + _ = tmp.Close() + _ = os.Remove(tmp.Name()) //nolint:gosec // G703: path is from os.CreateTemp, not from user input } - tempPath := tempFile.Name() - keepFile := false - defer func() { - _ = tempFile.Close() - if !keepFile { - _ = os.Remove(tempPath) - } - }() hasher := sha256.New() limited := &io.LimitedReader{R: reader, N: maxBytes + 1} - written, err := io.Copy(io.MultiWriter(tempFile, hasher), limited) - if err != nil { - return "", 0, "", fmt.Errorf("copy to temp file: %w", err) + written, copyErr := io.Copy(io.MultiWriter(tmp, hasher), limited) + if copyErr != nil { + cleanup() + return "", 0, nil, fmt.Errorf("copy to temp file: %w", copyErr) } if written > maxBytes { - return "", 0, "", fmt.Errorf("%w: max %d bytes", ErrAssetTooLarge, maxBytes) + cleanup() + return "", 0, nil, fmt.Errorf("%w: max %d bytes", ErrAssetTooLarge, maxBytes) } if written == 0 { - return "", 0, "", fmt.Errorf("asset payload is empty") + cleanup() + return "", 0, nil, errors.New("asset payload is empty") } - keepFile = true - return hex.EncodeToString(hasher.Sum(nil)), written, tempPath, nil + if _, seekErr := tmp.Seek(0, io.SeekStart); seekErr != nil { + cleanup() + return "", 0, nil, fmt.Errorf("seek temp file: %w", seekErr) + } + return hex.EncodeToString(hasher.Sum(nil)), written, tmp, nil } diff --git a/internal/memory/provider/builtin.go b/internal/memory/provider/builtin.go index 55f8437f..3ceb9e1c 100644 --- a/internal/memory/provider/builtin.go +++ b/internal/memory/provider/builtin.go @@ -2,7 +2,7 @@ package provider import ( "context" - "fmt" + "errors" "log/slog" "sort" "strings" @@ -65,7 +65,7 @@ func NewBuiltinProvider(log *slog.Logger, service any, chatAccessor conversation } } -func (p *BuiltinProvider) Type() string { return BuiltinType } +func (*BuiltinProvider) Type() string { return BuiltinType } // --- Conversation Hooks --- @@ -305,7 +305,7 @@ func (p *BuiltinProvider) canAccessChat(ctx context.Context, chatID, channelIden } } if p.chatAccessor == nil { - return false, fmt.Errorf("chat service not available") + return false, errors.New("chat service not available") } return p.chatAccessor.IsParticipant(ctx, chatID, channelIdentityID) } @@ -314,63 +314,63 @@ func (p *BuiltinProvider) canAccessChat(ctx context.Context, chatID, channelIden func (p *BuiltinProvider) Add(ctx context.Context, req AddRequest) (SearchResponse, error) { if p.service == nil { - return SearchResponse{}, fmt.Errorf("memory runtime not configured") + return SearchResponse{}, errors.New("memory runtime not configured") } return p.service.Add(ctx, req) } func (p *BuiltinProvider) Search(ctx context.Context, req SearchRequest) (SearchResponse, error) { if p.service == nil { - return SearchResponse{}, fmt.Errorf("memory runtime not configured") + return SearchResponse{}, errors.New("memory runtime not configured") } return p.service.Search(ctx, req) } func (p *BuiltinProvider) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error) { if p.service == nil { - return SearchResponse{}, fmt.Errorf("memory runtime not configured") + return SearchResponse{}, errors.New("memory runtime not configured") } return p.service.GetAll(ctx, req) } func (p *BuiltinProvider) Update(ctx context.Context, req UpdateRequest) (MemoryItem, error) { if p.service == nil { - return MemoryItem{}, fmt.Errorf("memory runtime not configured") + return MemoryItem{}, errors.New("memory runtime not configured") } return p.service.Update(ctx, req) } func (p *BuiltinProvider) Delete(ctx context.Context, memoryID string) (DeleteResponse, error) { if p.service == nil { - return DeleteResponse{}, fmt.Errorf("memory runtime not configured") + return DeleteResponse{}, errors.New("memory runtime not configured") } return p.service.Delete(ctx, memoryID) } func (p *BuiltinProvider) DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteResponse, error) { if p.service == nil { - return DeleteResponse{}, fmt.Errorf("memory runtime not configured") + return DeleteResponse{}, errors.New("memory runtime not configured") } return p.service.DeleteBatch(ctx, memoryIDs) } func (p *BuiltinProvider) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error) { if p.service == nil { - return DeleteResponse{}, fmt.Errorf("memory runtime not configured") + return DeleteResponse{}, errors.New("memory runtime not configured") } return p.service.DeleteAll(ctx, req) } func (p *BuiltinProvider) Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error) { if p.service == nil { - return CompactResult{}, fmt.Errorf("memory runtime not configured") + return CompactResult{}, errors.New("memory runtime not configured") } return p.service.Compact(ctx, filters, ratio, decayDays) } func (p *BuiltinProvider) Usage(ctx context.Context, filters map[string]any) (UsageResponse, error) { if p.service == nil { - return UsageResponse{}, fmt.Errorf("memory runtime not configured") + return UsageResponse{}, errors.New("memory runtime not configured") } return p.service.Usage(ctx, filters) } diff --git a/internal/memory/provider/registry.go b/internal/memory/provider/registry.go index d0df6cf6..392551ca 100644 --- a/internal/memory/provider/registry.go +++ b/internal/memory/provider/registry.go @@ -1,6 +1,7 @@ package provider import ( + "errors" "fmt" "log/slog" "strings" @@ -50,7 +51,7 @@ func (r *Registry) Register(id string, provider Provider) { func (r *Registry) Get(id string) (Provider, error) { id = strings.TrimSpace(id) if id == "" { - return nil, fmt.Errorf("provider id is required") + return nil, errors.New("provider id is required") } r.mu.RLock() p, ok := r.instances[id] diff --git a/internal/memory/provider/service.go b/internal/memory/provider/service.go index d4586148..5d21af8f 100644 --- a/internal/memory/provider/service.go +++ b/internal/memory/provider/service.go @@ -23,7 +23,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } -func (s *Service) ListMeta(_ context.Context) []ProviderMeta { +func (*Service) ListMeta(_ context.Context) []ProviderMeta { return []ProviderMeta{ { Provider: string(ProviderBuiltin), diff --git a/internal/memory/provider/types.go b/internal/memory/provider/types.go index 07a9632b..c0d27dde 100644 --- a/internal/memory/provider/types.go +++ b/internal/memory/provider/types.go @@ -23,7 +23,7 @@ type AfterChatRequest struct { Messages []Message } -// LLM is the interface for LLM operations needed by memory service +// LLM is the interface for LLM operations needed by memory service. type LLM interface { Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) diff --git a/internal/memory/service.go b/internal/memory/service.go index 840cedce..7021d7e6 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -177,7 +177,7 @@ func formatDayMarkdown(date string, records []writeRecord) string { }) var b strings.Builder - b.WriteString(fmt.Sprintf(memFileHeaderTemplate, date)) + fmt.Fprintf(&b, memFileHeaderTemplate, date) for _, r := range records { meta := map[string]string{"id": r.ID} if r.Topic != "" { diff --git a/internal/memory/service_test.go b/internal/memory/service_test.go index d1ed981c..d659254d 100644 --- a/internal/memory/service_test.go +++ b/internal/memory/service_test.go @@ -112,4 +112,3 @@ func TestRenderMemoryDayForDisplay_NonMemoryPathUnchanged(t *testing.T) { t.Fatalf("non-memory path should be unchanged, got: %s", out) } } - diff --git a/internal/memory/storefs/service.go b/internal/memory/storefs/service.go index e4c82db4..488980b7 100644 --- a/internal/memory/storefs/service.go +++ b/internal/memory/storefs/service.go @@ -80,7 +80,7 @@ func (s *Service) readFile(ctx context.Context, botID, filePath string) (string, if err != nil { return "", err } - defer reader.Close() + defer func() { _ = reader.Close() }() data, err := io.ReadAll(reader) if err != nil { return "", err @@ -449,6 +449,7 @@ func memoryDirPath() string { return path.Join(config.DefaultDataMount, "me func memoryDayPath(date string) string { return path.Join(memoryDirPath(), strings.TrimSpace(date)+".md") } + func memoryLegacyItemPath(id string) string { return path.Join(memoryDirPath(), strings.TrimSpace(id)+".md") } @@ -499,7 +500,7 @@ func formatMemoryDayMD(date string, items []MemoryItem) string { func parseMemoryDayMD(content string) ([]MemoryItem, error) { content = strings.TrimSpace(content) if content == "" { - return nil, fmt.Errorf("empty memory file") + return nil, errors.New("empty memory file") } lines := strings.Split(content, "\n") items := make([]MemoryItem, 0, 8) @@ -536,7 +537,7 @@ func parseMemoryDayMD(content string) ([]MemoryItem, error) { i = end } if len(items) == 0 { - return nil, fmt.Errorf("no memory entries found") + return nil, errors.New("no memory entries found") } return items, nil } @@ -544,11 +545,11 @@ func parseMemoryDayMD(content string) ([]MemoryItem, error) { func parseLegacyMemoryMD(content string) (MemoryItem, error) { content = strings.TrimSpace(content) if !strings.HasPrefix(content, "---") { - return MemoryItem{}, fmt.Errorf("missing frontmatter") + return MemoryItem{}, errors.New("missing frontmatter") } parts := strings.SplitN(content[3:], "---", 2) if len(parts) < 2 { - return MemoryItem{}, fmt.Errorf("incomplete frontmatter") + return MemoryItem{}, errors.New("incomplete frontmatter") } item := MemoryItem{Memory: strings.TrimSpace(parts[1])} for _, line := range strings.Split(strings.TrimSpace(parts[0]), "\n") { @@ -568,7 +569,7 @@ func parseLegacyMemoryMD(content string) (MemoryItem, error) { } } if item.ID == "" { - return MemoryItem{}, fmt.Errorf("missing id in frontmatter") + return MemoryItem{}, errors.New("missing id in frontmatter") } return item, nil } diff --git a/internal/message/service.go b/internal/message/service.go index cfb3acbb..70c3e1e8 100644 --- a/internal/message/service.go +++ b/internal/message/service.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log/slog" + "math" "strings" "time" @@ -104,6 +105,9 @@ func (s *DBService) Persist(ctx context.Context, input PersistInput) (Message, e s.logger.Warn("skip asset ref without content_hash") continue } + if ref.Ordinal < math.MinInt32 || ref.Ordinal > math.MaxInt32 { + return Message{}, fmt.Errorf("asset ordinal out of range: %d", ref.Ordinal) + } if _, assetErr := s.queries.CreateMessageAsset(ctx, sqlc.CreateMessageAssetParams{ MessageID: pgMsgID, Role: role, @@ -464,21 +468,12 @@ func coalesce(values ...string) string { return "" } -func toPgInt8(v int64) pgtype.Int8 { - if v == 0 { - return pgtype.Int8{} - } - return pgtype.Int8{Int64: v, Valid: true} -} - func parseJSONMap(data []byte) map[string]any { if len(data) == 0 { return nil } var m map[string]any - if err := json.Unmarshal(data, &m); err != nil { - slog.Warn("parseJSONMap: unmarshal failed", slog.Any("error", err)) - } + _ = json.Unmarshal(data, &m) return m } diff --git a/internal/models/models.go b/internal/models/models.go index 7554ca1a..59152398 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -5,25 +5,29 @@ import ( "errors" "fmt" "log/slog" + "math" "strings" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) -var ErrModelIDAlreadyExists = errors.New("model_id already exists") -var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers") +var ( + ErrModelIDAlreadyExists = errors.New("model_id already exists") + ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers") +) -// Service provides CRUD operations for models +// Service provides CRUD operations for models. type Service struct { queries *sqlc.Queries logger *slog.Logger } -// NewService creates a new models service +// NewService creates a new models service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -31,7 +35,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } -// Create adds a new model to the database +// Create adds a new model to the database. func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, error) { model := Model(req) if err := model.Validate(); err != nil { @@ -50,7 +54,7 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro } params := sqlc.CreateModelParams{ ModelID: model.ModelID, - LlmProviderID: llmProviderID, + LlmProviderID: llmProviderID, InputModalities: inputMod, SupportsReasoning: model.SupportsReasoning, Type: string(model.Type), @@ -66,7 +70,11 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro // Handle optional dimensions field (only for embedding models) if model.Type == ModelTypeEmbedding && model.Dimensions > 0 { - params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} + dimensions, err := intToInt4(model.Dimensions, "dimensions") + if err != nil { + return AddResponse{}, err + } + params.Dimensions = dimensions } created, err := s.queries.CreateModel(ctx, params) @@ -93,7 +101,7 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro }, nil } -// GetByID retrieves a model by its internal UUID +// GetByID retrieves a model by its internal UUID. func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) { uuid, err := db.ParseUUID(id) if err != nil { @@ -108,10 +116,10 @@ func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) { return convertToGetResponse(dbModel), nil } -// GetByModelID retrieves a model by its model_id field +// GetByModelID retrieves a model by its model_id field. func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse, error) { if modelID == "" { - return GetResponse{}, fmt.Errorf("model_id is required") + return GetResponse{}, errors.New("model_id is required") } dbModel, err := s.findUniqueByModelID(ctx, modelID) @@ -122,7 +130,7 @@ func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse return convertToGetResponse(dbModel), nil } -// List returns all models +// List returns all models. func (s *Service) List(ctx context.Context) ([]GetResponse, error) { dbModels, err := s.queries.ListModels(ctx) if err != nil { @@ -132,7 +140,7 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) { return convertToGetResponseList(dbModels), nil } -// ListByType returns models filtered by type (chat or embedding) +// ListByType returns models filtered by type (chat or embedding). func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) { if modelType != ModelTypeChat && modelType != ModelTypeEmbedding { return nil, fmt.Errorf("invalid model type: %s", modelType) @@ -146,7 +154,7 @@ func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetRes return convertToGetResponseList(dbModels), nil } -// ListByClientType returns models filtered by client type +// ListByClientType returns models filtered by client type. func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) { if !isValidClientType(clientType) { return nil, fmt.Errorf("invalid client type: %s", clientType) @@ -163,7 +171,7 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ( // ListByProviderID returns models filtered by provider ID. func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]GetResponse, error) { if strings.TrimSpace(providerID) == "" { - return nil, fmt.Errorf("provider id is required") + return nil, errors.New("provider id is required") } uuid, err := db.ParseUUID(providerID) if err != nil { @@ -182,7 +190,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string return nil, fmt.Errorf("invalid model type: %s", modelType) } if strings.TrimSpace(providerID) == "" { - return nil, fmt.Errorf("provider id is required") + return nil, errors.New("provider id is required") } uuid, err := db.ParseUUID(providerID) if err != nil { @@ -198,7 +206,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string return convertToGetResponseList(dbModels), nil } -// UpdateByID updates a model by its internal UUID +// UpdateByID updates a model by its internal UUID. func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { uuid, err := db.ParseUUID(id) if err != nil { @@ -236,7 +244,11 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) } if model.Type == ModelTypeEmbedding && model.Dimensions > 0 { - params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} + dimensions, err := intToInt4(model.Dimensions, "dimensions") + if err != nil { + return GetResponse{}, err + } + params.Dimensions = dimensions } updated, err := s.queries.UpdateModel(ctx, params) @@ -250,10 +262,10 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) return convertToGetResponse(updated), nil } -// UpdateByModelID updates a model by its model_id field +// UpdateByModelID updates a model by its model_id field. func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req UpdateRequest) (GetResponse, error) { if modelID == "" { - return GetResponse{}, fmt.Errorf("model_id is required") + return GetResponse{}, errors.New("model_id is required") } current, err := s.findUniqueByModelID(ctx, modelID) if err != nil { @@ -290,7 +302,11 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat } if model.Type == ModelTypeEmbedding && model.Dimensions > 0 { - params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} + dimensions, err := intToInt4(model.Dimensions, "dimensions") + if err != nil { + return GetResponse{}, err + } + params.Dimensions = dimensions } params.ModelID = model.ModelID @@ -306,7 +322,7 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat return convertToGetResponse(updated), nil } -// DeleteByID deletes a model by its internal UUID +// DeleteByID deletes a model by its internal UUID. func (s *Service) DeleteByID(ctx context.Context, id string) error { uuid, err := db.ParseUUID(id) if err != nil { @@ -320,10 +336,10 @@ func (s *Service) DeleteByID(ctx context.Context, id string) error { return nil } -// DeleteByModelID deletes a model by its model_id field +// DeleteByModelID deletes a model by its model_id field. func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error { if modelID == "" { - return fmt.Errorf("model_id is required") + return errors.New("model_id is required") } current, err := s.findUniqueByModelID(ctx, modelID) if err != nil { @@ -337,7 +353,7 @@ func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error { return nil } -// Count returns the total number of models +// Count returns the total number of models. func (s *Service) Count(ctx context.Context) (int64, error) { count, err := s.queries.CountModels(ctx) if err != nil { @@ -346,7 +362,7 @@ func (s *Service) Count(ctx context.Context) (int64, error) { return count, nil } -// CountByType returns the number of models of a specific type +// CountByType returns the number of models of a specific type. func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, error) { if modelType != ModelTypeChat && modelType != ModelTypeEmbedding { return 0, fmt.Errorf("invalid model type: %s", modelType) @@ -372,22 +388,22 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { }, } if dbModel.ClientType.Valid { - resp.Model.ClientType = ClientType(dbModel.ClientType.String) + resp.ClientType = ClientType(dbModel.ClientType.String) } - if resp.Model.Type == ModelTypeChat { - resp.Model.InputModalities = normalizeModalities(dbModel.InputModalities, []string{ModelInputText}) + if resp.Type == ModelTypeChat { + resp.InputModalities = normalizeModalities(dbModel.InputModalities, []string{ModelInputText}) } if dbModel.LlmProviderID.Valid { - resp.Model.LlmProviderID = dbModel.LlmProviderID.String() + resp.LlmProviderID = dbModel.LlmProviderID.String() } if dbModel.Name.Valid { - resp.Model.Name = dbModel.Name.String + resp.Name = dbModel.Name.String } if dbModel.Dimensions.Valid { - resp.Model.Dimensions = int(dbModel.Dimensions.Int32) + resp.Dimensions = int(dbModel.Dimensions.Int32) } return resp @@ -438,14 +454,14 @@ func isValidClientType(clientType ClientType) bool { // SelectMemoryModel selects a chat model for memory operations. func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) { if modelsService == nil { - return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + return GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured") } if queries == nil { - return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("queries not configured") + return GetResponse{}, sqlc.LlmProvider{}, errors.New("queries not configured") } candidates, err := modelsService.ListByType(ctx, ModelTypeChat) if err != nil || len(candidates) == 0 { - return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") + return GetResponse{}, sqlc.LlmProvider{}, errors.New("no chat models available for memory operations") } selected := candidates[0] provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID) @@ -465,7 +481,7 @@ func SelectMemoryModelForBot(ctx context.Context, modelsService *Service, querie // FetchProviderByID fetches a provider by ID. func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) { if strings.TrimSpace(providerID) == "" { - return sqlc.LlmProvider{}, fmt.Errorf("provider id missing") + return sqlc.LlmProvider{}, errors.New("provider id missing") } parsed, err := db.ParseUUID(providerID) if err != nil { @@ -473,3 +489,10 @@ func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID st } return queries.GetLlmProviderByID(ctx, parsed) } + +func intToInt4(value int, name string) (pgtype.Int4, error) { + if value < math.MinInt32 || value > math.MaxInt32 { + return pgtype.Int4{}, fmt.Errorf("%s out of range: %d", name, value) + } + return pgtype.Int4{Int32: int32(value), Valid: true}, nil +} diff --git a/internal/models/models_test.go b/internal/models/models_test.go index 426696e8..ba10db79 100644 --- a/internal/models/models_test.go +++ b/internal/models/models_test.go @@ -3,8 +3,9 @@ package models_test import ( "testing" - "github.com/memohai/memoh/internal/models" "github.com/stretchr/testify/assert" + + "github.com/memohai/memoh/internal/models" ) // This is an example test file demonstrating how to use the models service @@ -281,14 +282,14 @@ func TestModel_HasInputModality(t *testing.T) { func TestModelTypes(t *testing.T) { t.Run("ModelType constants", func(t *testing.T) { - assert.Equal(t, models.ModelType("chat"), models.ModelTypeChat) - assert.Equal(t, models.ModelType("embedding"), models.ModelTypeEmbedding) + assert.Equal(t, models.ModelTypeChat, models.ModelType("chat")) + assert.Equal(t, models.ModelTypeEmbedding, models.ModelType("embedding")) }) t.Run("ClientType constants", func(t *testing.T) { - assert.Equal(t, models.ClientType("openai-responses"), models.ClientTypeOpenAIResponses) - assert.Equal(t, models.ClientType("openai-completions"), models.ClientTypeOpenAICompletions) - assert.Equal(t, models.ClientType("anthropic-messages"), models.ClientTypeAnthropicMessages) - assert.Equal(t, models.ClientType("google-generative-ai"), models.ClientTypeGoogleGenerativeAI) + assert.Equal(t, models.ClientTypeOpenAIResponses, models.ClientType("openai-responses")) + assert.Equal(t, models.ClientTypeOpenAICompletions, models.ClientType("openai-completions")) + assert.Equal(t, models.ClientTypeAnthropicMessages, models.ClientType("anthropic-messages")) + assert.Equal(t, models.ClientTypeGoogleGenerativeAI, models.ClientType("google-generative-ai")) }) } diff --git a/internal/models/probe.go b/internal/models/probe.go index a1446d23..9a89dd0b 100644 --- a/internal/models/probe.go +++ b/internal/models/probe.go @@ -119,12 +119,12 @@ func probeReachable(ctx context.Context, baseURL string) (bool, string) { if err != nil { return false, err.Error() } - resp, err := http.DefaultClient.Do(req) + resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL if err != nil { return false, err.Error() } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() return true, "" } @@ -146,13 +146,13 @@ func doProbe(ctx context.Context, method, url string, headers map[string]string, } start := time.Now() - resp, err := http.DefaultClient.Do(req) + resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL latency := time.Since(start).Milliseconds() if err != nil { return probeResult{latencyMs: latency, message: err.Error()} } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() return probeResult{statusCode: resp.StatusCode, latencyMs: latency} } diff --git a/internal/models/types.go b/internal/models/types.go index 97869b01..0b1d518e 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -32,14 +32,14 @@ const ( ) type Model struct { - ModelID string `json:"model_id"` - Name string `json:"name"` - LlmProviderID string `json:"llm_provider_id"` - ClientType ClientType `json:"client_type,omitempty"` - InputModalities []string `json:"input_modalities,omitempty"` - SupportsReasoning bool `json:"supports_reasoning"` - Type ModelType `json:"type"` - Dimensions int `json:"dimensions"` + ModelID string `json:"model_id"` + Name string `json:"name"` + LlmProviderID string `json:"llm_provider_id"` + ClientType ClientType `json:"client_type,omitempty"` + InputModalities []string `json:"input_modalities,omitempty"` + SupportsReasoning bool `json:"supports_reasoning"` + Type ModelType `json:"type"` + Dimensions int `json:"dimensions"` } // validInputModalities is the set of recognised input modality tokens. diff --git a/internal/policy/service.go b/internal/policy/service.go index 2c476d1e..423653b9 100644 --- a/internal/policy/service.go +++ b/internal/policy/service.go @@ -2,7 +2,7 @@ package policy import ( "context" - "fmt" + "errors" "log/slog" "strings" @@ -36,11 +36,11 @@ func NewService(log *slog.Logger, botsService *bots.Service, settingsService *se // Resolve evaluates the full access policy for a bot. func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) { if s == nil || s.bots == nil || s.settings == nil { - return Decision{}, fmt.Errorf("policy service not configured") + return Decision{}, errors.New("policy service not configured") } botID = strings.TrimSpace(botID) if botID == "" { - return Decision{}, fmt.Errorf("bot id is required") + return Decision{}, errors.New("bot id is required") } bot, err := s.bots.Get(ctx, botID) if err != nil { @@ -82,7 +82,7 @@ func (s *Service) BotType(ctx context.Context, botID string) (string, error) { // BotOwnerUserID returns bot owner's user id. Implements router.PolicyService. func (s *Service) BotOwnerUserID(ctx context.Context, botID string) (string, error) { if s == nil || s.bots == nil { - return "", fmt.Errorf("policy service not configured") + return "", errors.New("policy service not configured") } bot, err := s.bots.Get(ctx, strings.TrimSpace(botID)) if err != nil { diff --git a/internal/preauth/service.go b/internal/preauth/service.go index f2abcc2a..139624ba 100644 --- a/internal/preauth/service.go +++ b/internal/preauth/service.go @@ -3,7 +3,6 @@ package preauth import ( "context" "errors" - "fmt" "strings" "time" @@ -28,7 +27,7 @@ func NewService(queries *sqlc.Queries) *Service { // Issue creates a new preauth key for the given bot. func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl time.Duration) (Key, error) { if s.queries == nil { - return Key{}, fmt.Errorf("preauth queries not configured") + return Key{}, errors.New("preauth queries not configured") } if ttl <= 0 { ttl = 24 * time.Hour @@ -61,7 +60,7 @@ func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl t func (s *Service) Get(ctx context.Context, token string) (Key, error) { if s.queries == nil { - return Key{}, fmt.Errorf("preauth queries not configured") + return Key{}, errors.New("preauth queries not configured") } row, err := s.queries.GetBotPreauthKey(ctx, strings.TrimSpace(token)) if err != nil { @@ -75,7 +74,7 @@ func (s *Service) Get(ctx context.Context, token string) (Key, error) { func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { if s.queries == nil { - return Key{}, fmt.Errorf("preauth queries not configured") + return Key{}, errors.New("preauth queries not configured") } pgID, err := db.ParseUUID(id) if err != nil { diff --git a/internal/providers/service.go b/internal/providers/service.go index dcc56d91..5ae74745 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -14,13 +14,13 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) -// Service handles provider operations +// Service handles provider operations. type Service struct { queries *sqlc.Queries logger *slog.Logger } -// NewService creates a new provider service +// NewService creates a new provider service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -28,7 +28,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } -// Create creates a new LLM provider +// Create creates a new LLM provider. func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) { // Marshal metadata metadataJSON, err := json.Marshal(req.Metadata) @@ -50,7 +50,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e return s.toGetResponse(provider), nil } -// Get retrieves a provider by ID +// Get retrieves a provider by ID. func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { providerID, err := db.ParseUUID(id) if err != nil { @@ -65,7 +65,7 @@ func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { return s.toGetResponse(provider), nil } -// GetByName retrieves a provider by name +// GetByName retrieves a provider by name. func (s *Service) GetByName(ctx context.Context, name string) (GetResponse, error) { provider, err := s.queries.GetLlmProviderByName(ctx, name) if err != nil { @@ -75,7 +75,7 @@ func (s *Service) GetByName(ctx context.Context, name string) (GetResponse, erro return s.toGetResponse(provider), nil } -// List retrieves all providers +// List retrieves all providers. func (s *Service) List(ctx context.Context) ([]GetResponse, error) { providers, err := s.queries.ListLlmProviders(ctx) if err != nil { @@ -89,7 +89,7 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) { return results, nil } -// Update updates an existing provider +// Update updates an existing provider. func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { providerID, err := db.ParseUUID(id) if err != nil { @@ -139,7 +139,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get return s.toGetResponse(updated), nil } -// Delete deletes a provider by ID +// Delete deletes a provider by ID. func (s *Service) Delete(ctx context.Context, id string) error { providerID, err := db.ParseUUID(id) if err != nil { @@ -152,7 +152,7 @@ func (s *Service) Delete(ctx context.Context, id string) error { return nil } -// Count returns the total count of providers +// Count returns the total count of providers. func (s *Service) Count(ctx context.Context) (int64, error) { count, err := s.queries.CountLlmProviders(ctx) if err != nil { @@ -215,11 +215,11 @@ func (s *Service) FetchRemoteModels(ctx context.Context, id string) ([]RemoteMod req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.ApiKey)) } - resp, err := http.DefaultClient.Do(req) + resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL if err != nil { return nil, fmt.Errorf("execute request: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -242,21 +242,23 @@ func probeReachable(ctx context.Context, baseURL string) (bool, string) { if err != nil { return false, err.Error() } - resp, err := http.DefaultClient.Do(req) + resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is from operator-configured LLM provider base URL if err != nil { return false, err.Error() } - io.Copy(io.Discard, resp.Body) - resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() return true, "" } -// toGetResponse converts a database provider to a response +// toGetResponse converts a database provider to a response. func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { var metadata map[string]any if len(provider.Metadata) > 0 { if err := json.Unmarshal(provider.Metadata, &metadata); err != nil { - slog.Warn("provider metadata unmarshal failed", slog.String("id", provider.ID.String()), slog.Any("error", err)) + if s.logger != nil { + s.logger.Warn("provider metadata unmarshal failed", slog.String("id", provider.ID.String()), slog.Any("error", err)) + } } } @@ -274,7 +276,7 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { } } -// maskAPIKey masks an API key for security +// maskAPIKey masks an API key for security. func maskAPIKey(apiKey string) string { if apiKey == "" { return "" diff --git a/internal/providers/types.go b/internal/providers/types.go index ef2b0473..bdf8e12c 100644 --- a/internal/providers/types.go +++ b/internal/providers/types.go @@ -2,40 +2,40 @@ package providers import "time" -// CreateRequest represents a request to create a new LLM provider +// CreateRequest represents a request to create a new LLM provider. type CreateRequest struct { Name string `json:"name" validate:"required"` BaseURL string `json:"base_url" validate:"required,url"` - APIKey string `json:"api_key"` + APIKey string `json:"api_key"` //nolint:gosec // intentional: LLM provider API key supplied by operator Metadata map[string]any `json:"metadata,omitempty"` } -// UpdateRequest represents a request to update an existing LLM provider +// UpdateRequest represents a request to update an existing LLM provider. type UpdateRequest struct { Name *string `json:"name,omitempty"` BaseURL *string `json:"base_url,omitempty"` - APIKey *string `json:"api_key,omitempty"` + APIKey *string `json:"api_key,omitempty"` //nolint:gosec // intentional: LLM provider API key update field Metadata map[string]any `json:"metadata,omitempty"` } -// GetResponse represents the response for getting a provider +// GetResponse represents the response for getting a provider. type GetResponse struct { ID string `json:"id"` Name string `json:"name"` BaseURL string `json:"base_url"` - APIKey string `json:"api_key,omitempty"` // masked in response + APIKey string `json:"api_key,omitempty"` //nolint:gosec // intentional: partially masked API key for display Metadata map[string]any `json:"metadata,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } -// ListResponse represents the response for listing providers +// ListResponse represents the response for listing providers. type ListResponse struct { Providers []GetResponse `json:"providers"` Total int64 `json:"total"` } -// CountResponse represents the count response +// CountResponse represents the count response. type CountResponse struct { Count int64 `json:"count"` } @@ -47,7 +47,7 @@ type TestResponse struct { Message string `json:"message,omitempty"` } -// RemoteModel represents a model returned by the provider's /v1/models endpoint +// RemoteModel represents a model returned by the provider's /v1/models endpoint. type RemoteModel struct { ID string `json:"id"` Object string `json:"object"` @@ -55,18 +55,18 @@ type RemoteModel struct { OwnedBy string `json:"owned_by"` } -// FetchModelsResponse represents the response from the provider's /v1/models endpoint +// FetchModelsResponse represents the response from the provider's /v1/models endpoint. type FetchModelsResponse struct { Object string `json:"object"` Data []RemoteModel `json:"data"` } -// ImportModelsRequest represents a request to import models from a provider +// ImportModelsRequest represents a request to import models from a provider. type ImportModelsRequest struct { ClientType string `json:"client_type"` } -// ImportModelsResponse represents the response for importing models +// ImportModelsResponse represents the response for importing models. type ImportModelsResponse struct { Created int `json:"created"` Skipped int `json:"skipped"` diff --git a/internal/schedule/service.go b/internal/schedule/service.go index 64042aa0..cb47979b 100644 --- a/internal/schedule/service.go +++ b/internal/schedule/service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "math" "strings" "sync" "time" @@ -48,14 +49,14 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, ru func (s *Service) Bootstrap(ctx context.Context) error { if s.queries == nil { - return fmt.Errorf("schedule queries not configured") + return errors.New("schedule queries not configured") } items, err := s.queries.ListEnabledSchedules(ctx) if err != nil { return err } for _, item := range items { - if err := s.scheduleJob(item); err != nil { + if err := s.scheduleJob(ctx, item); err != nil { return err } } @@ -64,10 +65,10 @@ func (s *Service) Bootstrap(ctx context.Context) error { func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) (Schedule, error) { if s.queries == nil { - return Schedule{}, fmt.Errorf("schedule queries not configured") + return Schedule{}, errors.New("schedule queries not configured") } if strings.TrimSpace(req.Name) == "" || strings.TrimSpace(req.Description) == "" || strings.TrimSpace(req.Pattern) == "" || strings.TrimSpace(req.Command) == "" { - return Schedule{}, fmt.Errorf("name, description, pattern, command are required") + return Schedule{}, errors.New("name, description, pattern, command are required") } if _, err := s.parser.Parse(req.Pattern); err != nil { return Schedule{}, fmt.Errorf("invalid cron pattern: %w", err) @@ -78,6 +79,9 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( } maxCalls := pgtype.Int4{Valid: false} if req.MaxCalls.Set && req.MaxCalls.Value != nil { + if *req.MaxCalls.Value < math.MinInt32 || *req.MaxCalls.Value > math.MaxInt32 { + return Schedule{}, fmt.Errorf("max_calls out of range: %d", *req.MaxCalls.Value) + } maxCalls = pgtype.Int4{Int32: int32(*req.MaxCalls.Value), Valid: true} } enabled := true @@ -97,7 +101,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( return Schedule{}, err } if row.Enabled { - if err := s.scheduleJob(row); err != nil { + if err := s.scheduleJob(ctx, row); err != nil { return Schedule{}, err } } @@ -112,7 +116,7 @@ func (s *Service) Get(ctx context.Context, id string) (Schedule, error) { row, err := s.queries.GetScheduleByID(ctx, pgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return Schedule{}, fmt.Errorf("schedule not found") + return Schedule{}, errors.New("schedule not found") } return Schedule{}, err } @@ -168,6 +172,9 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sch if req.MaxCalls.Value == nil { maxCalls = pgtype.Int4{Valid: false} } else { + if *req.MaxCalls.Value < math.MinInt32 || *req.MaxCalls.Value > math.MaxInt32 { + return Schedule{}, fmt.Errorf("max_calls out of range: %d", *req.MaxCalls.Value) + } maxCalls = pgtype.Int4{Int32: int32(*req.MaxCalls.Value), Valid: true} } } @@ -187,7 +194,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sch if err != nil { return Schedule{}, err } - if err := s.rescheduleJob(updated); err != nil { + if err := s.rescheduleJob(ctx, updated); err != nil { return Schedule{}, fmt.Errorf("reschedule job: %w", err) } return toSchedule(updated), nil @@ -207,14 +214,14 @@ func (s *Service) Delete(ctx context.Context, id string) error { func (s *Service) Trigger(ctx context.Context, scheduleID string) error { if s.triggerer == nil { - return fmt.Errorf("schedule triggerer not configured") + return errors.New("schedule triggerer not configured") } schedule, err := s.Get(ctx, scheduleID) if err != nil { return err } if !schedule.Enabled { - return fmt.Errorf("schedule is disabled") + return errors.New("schedule is disabled") } return s.runSchedule(ctx, schedule) } @@ -223,7 +230,7 @@ const scheduleTokenTTL = 10 * time.Minute func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error { if s.triggerer == nil { - return fmt.Errorf("schedule triggerer not configured") + return errors.New("schedule triggerer not configured") } updated, err := s.queries.IncrementScheduleCalls(ctx, toUUID(schedule.ID)) if err != nil { @@ -266,7 +273,7 @@ func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, er } ownerID := bot.OwnerUserID.String() if ownerID == "" { - return "", fmt.Errorf("bot owner not found") + return "", errors.New("bot owner not found") } return ownerID, nil } @@ -274,7 +281,7 @@ func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, er // generateTriggerToken creates a short-lived JWT for schedule trigger callbacks. func (s *Service) generateTriggerToken(userID string) (string, error) { if strings.TrimSpace(s.jwtSecret) == "" { - return "", fmt.Errorf("jwt secret not configured") + return "", errors.New("jwt secret not configured") } signed, _, err := auth.GenerateToken(userID, s.jwtSecret, scheduleTokenTTL) if err != nil { @@ -283,13 +290,13 @@ func (s *Service) generateTriggerToken(userID string) (string, error) { return "Bearer " + signed, nil } -func (s *Service) scheduleJob(schedule sqlc.Schedule) error { +func (s *Service) scheduleJob(ctx context.Context, schedule sqlc.Schedule) error { id := schedule.ID.String() if id == "" { - return fmt.Errorf("schedule id missing") + return errors.New("schedule id missing") } job := func() { - if err := s.runSchedule(context.Background(), toSchedule(schedule)); err != nil { + if err := s.runSchedule(context.WithoutCancel(ctx), toSchedule(schedule)); err != nil { s.logger.Error("scheduled job failed", slog.String("schedule_id", schedule.ID.String()), slog.Any("error", err)) } } @@ -303,14 +310,14 @@ func (s *Service) scheduleJob(schedule sqlc.Schedule) error { return nil } -func (s *Service) rescheduleJob(schedule sqlc.Schedule) error { +func (s *Service) rescheduleJob(ctx context.Context, schedule sqlc.Schedule) error { id := schedule.ID.String() if id == "" { return nil } s.removeJob(id) if schedule.Enabled { - return s.scheduleJob(schedule) + return s.scheduleJob(ctx, schedule) } return nil } @@ -337,8 +344,8 @@ func toSchedule(row sqlc.Schedule) Schedule { BotID: row.BotID.String(), } if row.MaxCalls.Valid { - max := int(row.MaxCalls.Int32) - item.MaxCalls = &max + maxCalls := int(row.MaxCalls.Int32) + item.MaxCalls = &maxCalls } if row.CreatedAt.Valid { item.CreatedAt = row.CreatedAt.Time diff --git a/internal/schedule/service_test.go b/internal/schedule/service_test.go index 15ec87fc..c9e9e9e5 100644 --- a/internal/schedule/service_test.go +++ b/internal/schedule/service_test.go @@ -1,7 +1,6 @@ package schedule import ( - "context" "log/slog" "strings" "testing" @@ -10,21 +9,6 @@ import ( "github.com/golang-jwt/jwt/v5" ) -type mockTriggerer struct { - called bool - botID string - payload TriggerPayload - token string -} - -func (m *mockTriggerer) TriggerSchedule(_ context.Context, botID string, payload TriggerPayload, token string) error { - m.called = true - m.botID = botID - m.payload = payload - m.token = token - return nil -} - func TestGenerateTriggerToken(t *testing.T) { secret := "test-secret-key-for-schedule" svc := &Service{ @@ -42,7 +26,7 @@ func TestGenerateTriggerToken(t *testing.T) { } raw := strings.TrimPrefix(tok, "Bearer ") - parsed, err := jwt.Parse(raw, func(token *jwt.Token) (any, error) { + parsed, err := jwt.Parse(raw, func(_ *jwt.Token) (any, error) { return []byte(secret), nil }) if err != nil { diff --git a/internal/searchproviders/service.go b/internal/searchproviders/service.go index 97433fe8..17a01b17 100644 --- a/internal/searchproviders/service.go +++ b/internal/searchproviders/service.go @@ -23,7 +23,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } -func (s *Service) ListMeta(_ context.Context) []ProviderMeta { +func (*Service) ListMeta(_ context.Context) []ProviderMeta { return []ProviderMeta{ { Provider: string(ProviderBrave), diff --git a/internal/settings/service.go b/internal/settings/service.go index 87c8326d..ce920512 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "math" "strings" "github.com/google/uuid" @@ -20,9 +21,11 @@ type Service struct { logger *slog.Logger } -var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access") -var ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers") -var ErrInvalidModelRef = errors.New("invalid model reference") +var ( + ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access") + ErrModelIDAmbiguous = errors.New("model_id is ambiguous across providers") + ErrInvalidModelRef = errors.New("invalid model reference") +) func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ @@ -45,7 +48,7 @@ func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) { func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) { if s.queries == nil { - return Settings{}, fmt.Errorf("settings queries not configured") + return Settings{}, errors.New("settings queries not configured") } pgID, err := db.ParseUUID(botID) if err != nil { @@ -122,6 +125,12 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest } memoryProviderUUID = providerID } + if current.MaxContextLoadTime < math.MinInt32 || current.MaxContextLoadTime > math.MaxInt32 || + current.MaxContextTokens < math.MinInt32 || current.MaxContextTokens > math.MaxInt32 || + current.MaxInboxItems < math.MinInt32 || current.MaxInboxItems > math.MaxInt32 || + current.HeartbeatInterval < math.MinInt32 || current.HeartbeatInterval > math.MaxInt32 { + return Settings{}, errors.New("settings numeric value out of int32 range") + } updated, err := s.queries.UpsertBotSettings(ctx, sqlc.UpsertBotSettingsParams{ ID: pgID, @@ -132,9 +141,9 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest AllowGuest: current.AllowGuest, ReasoningEnabled: current.ReasoningEnabled, ReasoningEffort: current.ReasoningEffort, - HeartbeatEnabled: current.HeartbeatEnabled, - HeartbeatInterval: int32(current.HeartbeatInterval), - HeartbeatPrompt: "", + HeartbeatEnabled: current.HeartbeatEnabled, + HeartbeatInterval: int32(current.HeartbeatInterval), + HeartbeatPrompt: "", ChatModelID: chatModelUUID, HeartbeatModelID: heartbeatModelUUID, SearchProviderID: searchProviderUUID, @@ -148,7 +157,7 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { - return fmt.Errorf("settings queries not configured") + return errors.New("settings queries not configured") } pgID, err := db.ParseUUID(botID) if err != nil { diff --git a/internal/settings/types.go b/internal/settings/types.go index cadaec4d..2abe9f33 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -9,9 +9,9 @@ const ( ) type Settings struct { - ChatModelID string `json:"chat_model_id"` - SearchProviderID string `json:"search_provider_id"` - MemoryProviderID string `json:"memory_provider_id"` + ChatModelID string `json:"chat_model_id"` + SearchProviderID string `json:"search_provider_id"` + MemoryProviderID string `json:"memory_provider_id"` MaxContextLoadTime int `json:"max_context_load_time"` MaxContextTokens int `json:"max_context_tokens"` MaxInboxItems int `json:"max_inbox_items"` @@ -19,15 +19,15 @@ type Settings struct { AllowGuest bool `json:"allow_guest"` ReasoningEnabled bool `json:"reasoning_enabled"` ReasoningEffort string `json:"reasoning_effort"` - HeartbeatEnabled bool `json:"heartbeat_enabled"` - HeartbeatInterval int `json:"heartbeat_interval"` - HeartbeatModelID string `json:"heartbeat_model_id"` + HeartbeatEnabled bool `json:"heartbeat_enabled"` + HeartbeatInterval int `json:"heartbeat_interval"` + HeartbeatModelID string `json:"heartbeat_model_id"` } type UpsertRequest struct { - ChatModelID string `json:"chat_model_id,omitempty"` - SearchProviderID string `json:"search_provider_id,omitempty"` - MemoryProviderID string `json:"memory_provider_id,omitempty"` + ChatModelID string `json:"chat_model_id,omitempty"` + SearchProviderID string `json:"search_provider_id,omitempty"` + MemoryProviderID string `json:"memory_provider_id,omitempty"` MaxContextLoadTime *int `json:"max_context_load_time,omitempty"` MaxContextTokens *int `json:"max_context_tokens,omitempty"` MaxInboxItems *int `json:"max_inbox_items,omitempty"` @@ -35,7 +35,7 @@ type UpsertRequest struct { AllowGuest *bool `json:"allow_guest,omitempty"` ReasoningEnabled *bool `json:"reasoning_enabled,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` - HeartbeatEnabled *bool `json:"heartbeat_enabled,omitempty"` - HeartbeatInterval *int `json:"heartbeat_interval,omitempty"` - HeartbeatModelID string `json:"heartbeat_model_id,omitempty"` + HeartbeatEnabled *bool `json:"heartbeat_enabled,omitempty"` + HeartbeatInterval *int `json:"heartbeat_interval,omitempty"` + HeartbeatModelID string `json:"heartbeat_model_id,omitempty"` } diff --git a/internal/storage/providers/containerfs/provider.go b/internal/storage/providers/containerfs/provider.go index 0ab1deec..e6eb120c 100644 --- a/internal/storage/providers/containerfs/provider.go +++ b/internal/storage/providers/containerfs/provider.go @@ -5,6 +5,7 @@ package containerfs import ( "context" + "errors" "fmt" "io" "path/filepath" @@ -71,7 +72,7 @@ func (p *Provider) Delete(ctx context.Context, key string) error { } // AccessPath returns the container-internal path for a storage key. -func (p *Provider) AccessPath(key string) string { +func (*Provider) AccessPath(key string) string { _, sub := splitRoutingKey(key) return filepath.Join("/data", containerMediaRoot, sub) } @@ -84,7 +85,7 @@ func (p *Provider) OpenContainerFile(ctx context.Context, botID, containerPath s } subPath := containerPath[len(dataPrefix):] if subPath == "" || strings.Contains(subPath, "..") { - return nil, fmt.Errorf("invalid container path") + return nil, errors.New("invalid container path") } client, err := p.clients.MCPClient(ctx, botID) if err != nil { diff --git a/internal/subagent/service.go b/internal/subagent/service.go index d0f20335..15991540 100644 --- a/internal/subagent/service.go +++ b/internal/subagent/service.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "log/slog" "strings" @@ -28,15 +27,15 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) (Subagent, error) { if s.queries == nil { - return Subagent{}, fmt.Errorf("subagent queries not configured") + return Subagent{}, errors.New("subagent queries not configured") } name := strings.TrimSpace(req.Name) if name == "" { - return Subagent{}, fmt.Errorf("name is required") + return Subagent{}, errors.New("name is required") } description := strings.TrimSpace(req.Description) if description == "" { - return Subagent{}, fmt.Errorf("description is required") + return Subagent{}, errors.New("description is required") } pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -76,7 +75,7 @@ func (s *Service) Get(ctx context.Context, id string) (Subagent, error) { row, err := s.queries.GetSubagentByID(ctx, pgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return Subagent{}, fmt.Errorf("subagent not found") + return Subagent{}, errors.New("subagent not found") } return Subagent{}, err } @@ -94,7 +93,7 @@ func (s *Service) GetByBotAndName(ctx context.Context, botID string, name string }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return Subagent{}, fmt.Errorf("subagent not found") + return Subagent{}, errors.New("subagent not found") } return Subagent{}, err } @@ -138,14 +137,14 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sub if req.Name != nil { name = strings.TrimSpace(*req.Name) if name == "" { - return Subagent{}, fmt.Errorf("name is required") + return Subagent{}, errors.New("name is required") } } description := existing.Description if req.Description != nil { description = strings.TrimSpace(*req.Description) if description == "" { - return Subagent{}, fmt.Errorf("description is required") + return Subagent{}, errors.New("description is required") } } metadata := existing.Metadata diff --git a/spec/docs.go b/spec/docs.go index 130e353f..470e8a90 100644 --- a/spec/docs.go +++ b/spec/docs.go @@ -10374,7 +10374,7 @@ const docTemplate = `{ } }` -// SwaggerInfo holds exported Swagger Info so clients can modify it +// SwaggerInfo holds exported Swagger Info so clients can modify it. var SwaggerInfo = &swag.Spec{ Version: "1.0.0", Host: "",