mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
260 lines
5.8 KiB
Go
260 lines
5.8 KiB
Go
package runtime
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"runtime"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/BurntSushi/toml"
|
|
|
|
"github.com/memohai/memoh/internal/config"
|
|
"github.com/memohai/memoh/internal/embedded"
|
|
)
|
|
|
|
type Manager struct {
|
|
log *slog.Logger
|
|
cfg config.Config
|
|
host string
|
|
port int
|
|
workdir string
|
|
cmd *exec.Cmd
|
|
stopOnce sync.Once
|
|
}
|
|
|
|
const (
	// Address used when the configuration leaves the agent gateway host or
	// port unset.
	defaultGatewayHost = "127.0.0.1"
	defaultGatewayPort = 8081

	// File names looked up inside the extracted agent directory.
	agentConfigFileName = "config.toml"
	agentBinName        = "agent-bin"
	// agentUnavailableMarker is the file whose presence, when the binary is
	// missing, means no agent binary is bundled for this platform.
	agentUnavailableMarker = "UNAVAILABLE"

	// healthCheckTimeout is the total time waitHealthy keeps probing.
	healthCheckTimeout = 30 * time.Second
	// healthCheckRetryBackoff is the pause between two health probes.
	healthCheckRetryBackoff = 400 * time.Millisecond
	// processStopTimeout is how long Stop waits after the interrupt before
	// killing the process.
	processStopTimeout = 5 * time.Second
)
|
|
|
|
func NewManager(log *slog.Logger, cfg config.Config) *Manager {
|
|
host := cfg.AgentGateway.Host
|
|
if host == "" {
|
|
host = defaultGatewayHost
|
|
}
|
|
port := cfg.AgentGateway.Port
|
|
if port == 0 {
|
|
port = defaultGatewayPort
|
|
}
|
|
return &Manager{
|
|
log: log.With(slog.String("component", "agent-runtime")),
|
|
cfg: cfg,
|
|
host: host,
|
|
port: port,
|
|
}
|
|
}
|
|
|
|
func (m *Manager) Start(ctx context.Context) error {
|
|
workdir, err := os.MkdirTemp("", "memoh-agent-runtime-*")
|
|
if err != nil {
|
|
return fmt.Errorf("create runtime temp dir: %w", err)
|
|
}
|
|
m.workdir = workdir
|
|
|
|
agentFS, err := embedded.AgentFS()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
agentDir := filepath.Join(workdir, "agent")
|
|
if err := extractFS(agentFS, agentDir); err != nil {
|
|
return fmt.Errorf("extract agent assets: %w", err)
|
|
}
|
|
|
|
agentBinPath := filepath.Join(agentDir, agentBinaryNameForRuntime())
|
|
if _, err := os.Stat(agentBinPath); err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
markerPath := filepath.Join(agentDir, agentUnavailableMarker)
|
|
if _, markerErr := os.Stat(markerPath); markerErr == nil {
|
|
m.log.Warn("bundled agent binary unavailable for current platform; falling back to configured agent gateway", slog.String("platform", runtimePlatform()))
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("agent binary missing: %w", err)
|
|
}
|
|
if err := os.Chmod(agentBinPath, 0o755); err != nil {
|
|
return fmt.Errorf("chmod agent binary: %w", err)
|
|
}
|
|
agentConfigPath := filepath.Join(agentDir, agentConfigFileName)
|
|
if err := writeAgentConfig(agentConfigPath, m.cfg); err != nil {
|
|
return err
|
|
}
|
|
|
|
cmd := exec.Command(agentBinPath)
|
|
cmd.Dir = agentDir
|
|
cmd.Env = append(
|
|
os.Environ(),
|
|
"MEMOH_CONFIG_PATH="+agentConfigPath,
|
|
"CONFIG_PATH="+agentConfigPath,
|
|
)
|
|
cmd.Stdout = &logWriter{log: m.log, level: slog.LevelInfo}
|
|
cmd.Stderr = &logWriter{log: m.log, level: slog.LevelError}
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
return fmt.Errorf("start bundled agent runtime: %w", err)
|
|
}
|
|
m.cmd = cmd
|
|
|
|
m.log.Info("bundled agent runtime started", slog.Int("pid", cmd.Process.Pid), slog.String("addr", m.address()))
|
|
if err := m.waitHealthy(ctx); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) Stop(ctx context.Context) error {
|
|
var retErr error
|
|
m.stopOnce.Do(func() {
|
|
if m.cmd == nil || m.cmd.Process == nil {
|
|
return
|
|
}
|
|
|
|
_ = m.cmd.Process.Signal(os.Interrupt)
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- m.cmd.Wait()
|
|
}()
|
|
|
|
select {
|
|
case err := <-done:
|
|
if err != nil && !errors.Is(err, syscall.EINTR) {
|
|
retErr = err
|
|
}
|
|
case <-ctx.Done():
|
|
_ = m.cmd.Process.Kill()
|
|
retErr = ctx.Err()
|
|
case <-time.After(processStopTimeout):
|
|
_ = m.cmd.Process.Kill()
|
|
<-done
|
|
}
|
|
|
|
if m.workdir != "" {
|
|
_ = os.RemoveAll(m.workdir)
|
|
}
|
|
})
|
|
return retErr
|
|
}
|
|
|
|
func (m *Manager) waitHealthy(ctx context.Context) error {
|
|
client := &http.Client{Timeout: 2 * time.Second}
|
|
healthURL := fmt.Sprintf("http://%s/health", m.address())
|
|
deadline := time.Now().Add(healthCheckTimeout)
|
|
for time.Now().Before(deadline) {
|
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
|
|
resp, err := client.Do(req)
|
|
if err == nil {
|
|
_ = resp.Body.Close()
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
return nil
|
|
}
|
|
}
|
|
time.Sleep(healthCheckRetryBackoff)
|
|
}
|
|
return fmt.Errorf("bundled agent runtime health check timeout: %s", healthURL)
|
|
}
|
|
|
|
func (m *Manager) address() string {
|
|
return fmt.Sprintf("%s:%d", m.host, m.port)
|
|
}
|
|
|
|
// extractFS copies every entry of src into targetDir on the local file
// system, recreating the directory layout.
//
// Directories are created with mode 0o755. Files are created, or truncated if
// they already exist, with mode 0o644. The first error encountered aborts the
// walk and is returned.
func extractFS(src fs.FS, targetDir string) error {
	if err := os.MkdirAll(targetDir, 0o755); err != nil {
		return err
	}

	// copyFile streams one entry of src to dest on disk.
	copyFile := func(name, dest string) error {
		in, err := src.Open(name)
		if err != nil {
			return err
		}
		defer in.Close()

		out, err := os.OpenFile(dest, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
		if err != nil {
			return err
		}
		_, copyErr := io.Copy(out, in)
		closeErr := out.Close()
		if copyErr != nil {
			return copyErr
		}
		return closeErr
	}

	visit := func(name string, entry fs.DirEntry, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case name == ".":
			// The root maps to targetDir, which already exists.
			return nil
		}

		dest := filepath.Join(targetDir, name)
		if entry.IsDir() {
			return os.MkdirAll(dest, 0o755)
		}
		if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
			return err
		}
		return copyFile(name, dest)
	}

	return fs.WalkDir(src, ".", visit)
}
|
|
|
|
func writeAgentConfig(path string, cfg config.Config) error {
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
return err
|
|
}
|
|
f, err := os.Create(path)
|
|
if err != nil {
|
|
return fmt.Errorf("create agent config: %w", err)
|
|
}
|
|
defer f.Close()
|
|
return toml.NewEncoder(f).Encode(cfg)
|
|
}
|
|
|
|
// logWriter is an io.Writer that forwards everything written to it to a
// structured logger at a fixed level.
type logWriter struct {
	// log receives one record per Write call.
	log *slog.Logger
	// level is the severity every record is emitted with.
	level slog.Level
}
|
|
|
|
func (w *logWriter) Write(p []byte) (n int, err error) {
|
|
msg := string(p)
|
|
msg = trimTrailingNewline(msg)
|
|
if msg != "" {
|
|
w.log.Log(context.Background(), w.level, msg)
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
// trimTrailingNewline returns s without any trailing '\n' or '\r' bytes.
// Line breaks elsewhere in s are left untouched.
func trimTrailingNewline(s string) string {
	end := len(s)
	for end > 0 && (s[end-1] == '\n' || s[end-1] == '\r') {
		end--
	}
	return s[:end]
}
|
|
|
|
// runtimePlatform describes the platform this binary runs on as
// "GOOS/GOARCH", e.g. "linux/amd64".
func runtimePlatform() string {
	goos, goarch := runtime.GOOS, runtime.GOARCH
	return fmt.Sprintf("%s/%s", goos, goarch)
}
|
|
|
|
func agentBinaryNameForRuntime() string {
|
|
if runtime.GOOS == "windows" {
|
|
return agentBinName + ".exe"
|
|
}
|
|
return agentBinName
|
|
}
|