diff --git a/cmd/cli/main.go b/cmd/cli/main.go index e51e8355..45f8c675 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -1,374 +1,110 @@ package main import ( - "bufio" - "bytes" - "context" - "encoding/json" "flag" - "fmt" "io" - "log/slog" - "net/http" "os" - "strings" + "os/exec" + "runtime" + "strconv" + "sync" "time" - - "github.com/memohai/memoh/internal/chat" - "github.com/memohai/memoh/internal/config" - "github.com/memohai/memoh/internal/version" - "github.com/memohai/memoh/internal/logger" ) -type cliOptions struct { - configPath string - username string - password string - timeout time.Duration - apiBaseURL string - jwtToken string - showVersion bool -} - func main() { - opts := parseFlags() - if opts.showVersion { - fmt.Printf("Memoh CLI %s\n", version.GetInfo()) - return - } - ctx := context.Background() - - cfg, err := config.Load(opts.configPath) - if err != nil { - fmt.Fprintf(os.Stderr, "load config: %v\n", err) - os.Exit(1) - } - - logger.Init(cfg.Log.Level, cfg.Log.Format) - - if strings.TrimSpace(opts.apiBaseURL) == "" { - opts.apiBaseURL = defaultAPIBaseURL(cfg.Server.Addr) - } - if strings.TrimSpace(opts.apiBaseURL) == "" { - logger.Error("api url is required") - os.Exit(1) - } - opts.apiBaseURL = normalizeBaseURL(opts.apiBaseURL) - - jwtToken := strings.TrimSpace(opts.jwtToken) - client := &http.Client{Timeout: opts.timeout} - if jwtToken == "" { - username, password, err := resolveLoginCredentials(opts, cfg) - if err != nil { - logger.Error("resolve login", slog.Any("error", err)) - os.Exit(1) - } - loginCtx := ctx - if opts.timeout > 0 { - var cancel context.CancelFunc - loginCtx, cancel = context.WithTimeout(ctx, opts.timeout) - defer cancel() - } - jwtToken, err = resolveJWTToken(loginCtx, client, opts.apiBaseURL, username, password) - if err != nil { - logger.Error("resolve jwt", slog.Any("error", err)) - os.Exit(1) - } - } - - query := strings.TrimSpace(strings.Join(flag.Args(), " ")) - if query != "" { - if err := sendChat(ctx, client, opts.apiBaseURL, jwtToken, query); err != nil { - logger.Error("chat failed", slog.Any("error", err)) - os.Exit(1) - } - return - } - if err := runInteractive(ctx, client, opts.apiBaseURL, jwtToken); err != nil { - logger.Error("chat failed", slog.Any("error", err)) - os.Exit(1) - } -} - -func parseFlags() cliOptions { - var opts cliOptions - defaultConfig := os.Getenv("CONFIG_PATH") - if strings.TrimSpace(defaultConfig) == "" { - defaultConfig = config.DefaultConfigPath - } - - flag.StringVar(&opts.configPath, "config", defaultConfig, "Path to config.toml") - flag.StringVar(&opts.username, "username", "", "Username for login") - flag.StringVar(&opts.password, "password", "", "Password for login (or set MEMOH_PASSWORD)") - flag.StringVar(&opts.jwtToken, "jwt", "", "JWT token (optional)") - flag.StringVar(&opts.apiBaseURL, "api-url", "", "API server base URL (e.g. http://127.0.0.1:8080)") - flag.DurationVar(&opts.timeout, "timeout", 30*time.Second, "Request timeout") - flag.BoolVar(&opts.showVersion, "version", false, "Show version information") + flag.CommandLine.SetOutput(io.Discard) + containerID := flag.String("container-id", "", "") flag.Parse() - return opts -} - -func normalizeBaseURL(value string) string { - return strings.TrimRight(strings.TrimSpace(value), "/") -} - -func defaultAPIBaseURL(addr string) string { - trimmed := strings.TrimSpace(addr) - if trimmed == "" { - return "" - } - if strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") { - return normalizeBaseURL(trimmed) - } - if strings.HasPrefix(trimmed, ":") { - return "http://127.0.0.1" + trimmed - } - return "http://" + trimmed -} - -func resolveLoginCredentials(opts cliOptions, cfg config.Config) (string, string, error) { - username := strings.TrimSpace(opts.username) - if username == "" { - username = strings.TrimSpace(cfg.Admin.Username) - } - if username == "" { - return "", "", fmt.Errorf("username is required for login") + if *containerID == "" { + os.Exit(2) } - password := strings.TrimSpace(opts.password) - if password == "" { - password = strings.TrimSpace(os.Getenv("MEMOH_PASSWORD")) - } - if password == "" { - if candidate := strings.TrimSpace(cfg.Admin.Password); candidate != "" && candidate != "change-your-password-here" { - password = candidate + cmd := buildMCPCommand(*containerID) + if err := runWithStdio(cmd); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) } + os.Exit(1) } - if password == "" { - return "", "", fmt.Errorf("password is required; pass --password or set MEMOH_PASSWORD") - } - return username, password, nil } -func resolveJWTToken(ctx context.Context, client *http.Client, baseURL, username, password string) (string, error) { - resp, err := loginForToken(ctx, client, baseURL, username, password) - if err != nil { - return "", err +func buildMCPCommand(containerID string) *exec.Cmd { + execID := "mcp-" + strconv.FormatInt(time.Now().UnixNano(), 10) + if runtime.GOOS == "darwin" { + return exec.Command( + "limactl", + "shell", + "--tty=false", + "default", + "--", + "sudo", + "-n", + "ctr", + "-n", + "default", + "tasks", + "exec", + "--exec-id", + execID, + containerID, + "/mcp", + ) } - if strings.TrimSpace(resp.AccessToken) == "" { - return "", fmt.Errorf("login succeeded but token missing") - } - return resp.AccessToken, nil + return exec.Command( + "ctr", + "-n", + "default", + "tasks", + "exec", + "--exec-id", + execID, + containerID, + "/mcp", + ) } -type loginRequest struct { - Username string `json:"username"` - Password string `json:"password"` -} - -type loginResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresAt string `json:"expires_at"` - UserID string `json:"user_id"` - Username string `json:"username"` -} - -func loginForToken(ctx context.Context, client *http.Client, baseURL, username, password string) (loginResponse, error) { - body, err := json.Marshal(loginRequest{ - Username: username, - Password: password, - }) - if err != nil { - return loginResponse{}, err - } - url := normalizeBaseURL(baseURL) + "/auth/login" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { - return loginResponse{}, err - } - req.Header.Set("Content-Type", "application/json") - - resp, err := client.Do(req) - if err != nil { - return loginResponse{}, err - } - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - payload, _ := io.ReadAll(resp.Body) - return loginResponse{}, fmt.Errorf("login failed: %s", strings.TrimSpace(string(payload))) - } - - var parsed loginResponse - if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { - return loginResponse{}, err - } - return parsed, nil -} - -func runInteractive(ctx context.Context, client *http.Client, baseURL, jwtToken string) error { - reader := bufio.NewScanner(os.Stdin) - reader.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) - - fmt.Fprint(os.Stdout, "You: ") - for reader.Scan() { - line := strings.TrimSpace(reader.Text()) - if line == "" { - fmt.Fprint(os.Stdout, "You: ") - continue - } - lower := strings.ToLower(line) - if lower == "exit" || lower == "quit" { - return nil - } - if err := sendChat(ctx, client, baseURL, jwtToken, line); err != nil { - return err - } - fmt.Fprint(os.Stdout, "You: ") - } - return reader.Err() -} - -func sendChat(ctx context.Context, client *http.Client, baseURL, jwtToken, query string) error { - return streamAPIChat(ctx, client, baseURL, jwtToken, chat.ChatRequest{Query: query}) -} - -func streamAPIChat(ctx context.Context, client *http.Client, baseURL, jwtToken string, req chat.ChatRequest) error { - body, err := json.Marshal(req) +func runWithStdio(cmd *exec.Cmd) error { + stdin, err := cmd.StdinPipe() if err != nil { return err } - url := baseURL + "/chat/stream" - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + stdout, err := cmd.StdoutPipe() if err != nil { + _ = stdin.Close() return err } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", "text/event-stream") - httpReq.Header.Set("Authorization", "Bearer "+jwtToken) - - resp, err := client.Do(httpReq) + stderr, err := cmd.StderrPipe() if err != nil { + _ = stdin.Close() + _ = stdout.Close() return err } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - payload, _ := io.ReadAll(resp.Body) - return fmt.Errorf("api server error: %s", strings.TrimSpace(string(payload))) - } - - scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) - - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "event:") { - continue - } - if !strings.HasPrefix(line, "data:") { - continue - } - data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if data == "" || data == "[DONE]" { - break - } - if text, ok := extractStreamText([]byte(data)); ok { - fmt.Fprint(os.Stdout, text) - } - } - - if err := scanner.Err(); err != nil { + if err := cmd.Start(); err != nil { + _ = stdin.Close() + _ = stdout.Close() + _ = stderr.Close() return err } - fmt.Fprintln(os.Stdout) - return nil -} -func renderMessageContent(raw interface{}) (string, bool) { - switch v := raw.(type) { - case string: - return v, true - case []interface{}: - parts := make([]string, 0, len(v)) - for _, item := range v { - if text, ok := renderMessageContent(item); ok && strings.TrimSpace(text) != "" { - parts = append(parts, text) - } - } - if len(parts) > 0 { - return strings.Join(parts, ""), true - } - case map[string]interface{}: - if text, ok := v["text"].(string); ok && strings.TrimSpace(text) != "" { - return text, true - } - if text, ok := v["content"].(string); ok && strings.TrimSpace(text) != "" { - return text, true - } - if kind, ok := v["type"].(string); ok && kind == "text" { - if text, ok := v["text"].(string); ok { - return text, true - } - } - } - return "", false -} + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + _, _ = io.Copy(stdin, os.Stdin) + _ = stdin.Close() + }() + go func() { + defer wg.Done() + _, _ = io.Copy(os.Stdout, stdout) + }() + go func() { + defer wg.Done() + _, _ = io.Copy(os.Stderr, stderr) + }() -func extractStreamText(raw []byte) (string, bool) { - var payload interface{} - if err := json.Unmarshal(raw, &payload); err != nil { - return "", false - } - return extractTextFromPayload(payload) -} - -func extractTextFromPayload(payload interface{}) (string, bool) { - switch v := payload.(type) { - case string: - if strings.TrimSpace(v) == "" { - return "", false - } - return v, true - case []interface{}: - parts := make([]string, 0, len(v)) - for _, item := range v { - if text, ok := extractTextFromPayload(item); ok && strings.TrimSpace(text) != "" { - parts = append(parts, text) - } - } - if len(parts) > 0 { - return strings.Join(parts, ""), true - } - case map[string]interface{}: - if text, ok := v["textDelta"].(string); ok && strings.TrimSpace(text) != "" { - return text, true - } - if text, ok := v["text"].(string); ok && strings.TrimSpace(text) != "" { - return text, true - } - if content, ok := v["content"]; ok { - if text, ok := renderMessageContent(content); ok { - return text, true - } - } - if delta, ok := v["delta"]; ok { - if text, ok := extractTextFromPayload(delta); ok { - return text, true - } - } - if data, ok := v["data"]; ok { - if text, ok := extractTextFromPayload(data); ok { - return text, true - } - } - if msg, ok := v["message"]; ok { - if text, ok := extractTextFromPayload(msg); ok { - return text, true - } - } - } - return "", false + err = cmd.Wait() + wg.Wait() + return err }