package runtime import ( "context" "errors" "fmt" "io" "io/fs" "log/slog" "net/http" "os" "os/exec" "path/filepath" "runtime" "sync" "syscall" "time" "github.com/BurntSushi/toml" "github.com/memohai/memoh/internal/config" "github.com/memohai/memoh/internal/embedded" ) type Manager struct { log *slog.Logger cfg config.Config host string port int workdir string cmd *exec.Cmd stopOnce sync.Once } const ( defaultGatewayHost = "127.0.0.1" defaultGatewayPort = 8081 agentConfigFileName = "config.toml" agentBinName = "agent-bin" agentUnavailableMarker = "UNAVAILABLE" healthCheckTimeout = 30 * time.Second healthCheckRetryBackoff = 400 * time.Millisecond processStopTimeout = 5 * time.Second ) func NewManager(log *slog.Logger, cfg config.Config) *Manager { host := cfg.AgentGateway.Host if host == "" { host = defaultGatewayHost } port := cfg.AgentGateway.Port if port == 0 { port = defaultGatewayPort } return &Manager{ log: log.With(slog.String("component", "agent-runtime")), cfg: cfg, host: host, port: port, } } func (m *Manager) Start(ctx context.Context) error { workdir, err := os.MkdirTemp("", "memoh-agent-runtime-*") if err != nil { return fmt.Errorf("create runtime temp dir: %w", err) } m.workdir = workdir agentFS, err := embedded.AgentFS() if err != nil { return err } agentDir := filepath.Join(workdir, "agent") if err := extractFS(agentFS, agentDir); err != nil { return fmt.Errorf("extract agent assets: %w", err) } agentBinPath := filepath.Join(agentDir, agentBinaryNameForRuntime()) if _, err := os.Stat(agentBinPath); err != nil { if errors.Is(err, os.ErrNotExist) { markerPath := filepath.Join(agentDir, agentUnavailableMarker) if _, markerErr := os.Stat(markerPath); markerErr == nil { m.log.Warn("bundled agent binary unavailable for current platform; falling back to configured agent gateway", slog.String("platform", runtimePlatform())) return nil } } return fmt.Errorf("agent binary missing: %w", err) } if err := os.Chmod(agentBinPath, 0o755); err != nil { return fmt.Errorf("chmod agent binary: %w", err) } agentConfigPath := filepath.Join(agentDir, agentConfigFileName) if err := writeAgentConfig(agentConfigPath, m.cfg); err != nil { return err } cmd := exec.Command(agentBinPath) cmd.Dir = agentDir cmd.Env = append( os.Environ(), "MEMOH_CONFIG_PATH="+agentConfigPath, "CONFIG_PATH="+agentConfigPath, ) cmd.Stdout = &logWriter{log: m.log, level: slog.LevelInfo} cmd.Stderr = &logWriter{log: m.log, level: slog.LevelError} if err := cmd.Start(); err != nil { return fmt.Errorf("start bundled agent runtime: %w", err) } m.cmd = cmd m.log.Info("bundled agent runtime started", slog.Int("pid", cmd.Process.Pid), slog.String("addr", m.address())) if err := m.waitHealthy(ctx); err != nil { return err } return nil } func (m *Manager) Stop(ctx context.Context) error { var retErr error m.stopOnce.Do(func() { if m.cmd == nil || m.cmd.Process == nil { return } _ = m.cmd.Process.Signal(os.Interrupt) done := make(chan error, 1) go func() { done <- m.cmd.Wait() }() select { case err := <-done: if err != nil && !errors.Is(err, syscall.EINTR) { retErr = err } case <-ctx.Done(): _ = m.cmd.Process.Kill() retErr = ctx.Err() case <-time.After(processStopTimeout): _ = m.cmd.Process.Kill() <-done } if m.workdir != "" { _ = os.RemoveAll(m.workdir) } }) return retErr } func (m *Manager) waitHealthy(ctx context.Context) error { client := &http.Client{Timeout: 2 * time.Second} healthURL := fmt.Sprintf("http://%s/health", m.address()) deadline := time.Now().Add(healthCheckTimeout) for time.Now().Before(deadline) { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil) resp, err := client.Do(req) if err == nil { _ = resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } } time.Sleep(healthCheckRetryBackoff) } return fmt.Errorf("bundled agent runtime health check timeout: %s", healthURL) } func (m *Manager) address() string { return fmt.Sprintf("%s:%d", m.host, m.port) } func extractFS(src fs.FS, targetDir string) error { if err := os.MkdirAll(targetDir, 0o755); err != nil { return err } return fs.WalkDir(src, ".", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } if path == "." { return nil } target := filepath.Join(targetDir, path) if d.IsDir() { return os.MkdirAll(target, 0o755) } if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { return err } r, err := src.Open(path) if err != nil { return err } defer r.Close() w, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) if err != nil { return err } if _, err := io.Copy(w, r); err != nil { _ = w.Close() return err } return w.Close() }) } func writeAgentConfig(path string, cfg config.Config) error { if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return err } f, err := os.Create(path) if err != nil { return fmt.Errorf("create agent config: %w", err) } defer f.Close() return toml.NewEncoder(f).Encode(cfg) } type logWriter struct { log *slog.Logger level slog.Level } func (w *logWriter) Write(p []byte) (n int, err error) { msg := string(p) msg = trimTrailingNewline(msg) if msg != "" { w.log.Log(context.Background(), w.level, msg) } return len(p), nil } func trimTrailingNewline(s string) string { for len(s) > 0 { last := s[len(s)-1] if last != '\n' && last != '\r' { break } s = s[:len(s)-1] } return s } func runtimePlatform() string { return fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) } func agentBinaryNameForRuntime() string { if runtime.GOOS == "windows" { return agentBinName + ".exe" } return agentBinName }