feat: mcp transpipe cli

This commit is contained in:
Ran
2026-02-03 02:11:32 +08:00
parent ddd853fee4
commit d775c6df6d
+74 -338
View File
@@ -1,374 +1,110 @@
package main
import (
"bufio"
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"os/exec"
"runtime"
"strconv"
"sync"
"time"
"github.com/memohai/memoh/internal/chat"
"github.com/memohai/memoh/internal/config"
"github.com/memohai/memoh/internal/version"
"github.com/memohai/memoh/internal/logger"
)
type cliOptions struct {
configPath string
username string
password string
timeout time.Duration
apiBaseURL string
jwtToken string
showVersion bool
}
func main() {
opts := parseFlags()
if opts.showVersion {
fmt.Printf("Memoh CLI %s\n", version.GetInfo())
return
}
ctx := context.Background()
cfg, err := config.Load(opts.configPath)
if err != nil {
fmt.Fprintf(os.Stderr, "load config: %v\n", err)
os.Exit(1)
}
logger.Init(cfg.Log.Level, cfg.Log.Format)
if strings.TrimSpace(opts.apiBaseURL) == "" {
opts.apiBaseURL = defaultAPIBaseURL(cfg.Server.Addr)
}
if strings.TrimSpace(opts.apiBaseURL) == "" {
logger.Error("api url is required")
os.Exit(1)
}
opts.apiBaseURL = normalizeBaseURL(opts.apiBaseURL)
jwtToken := strings.TrimSpace(opts.jwtToken)
client := &http.Client{Timeout: opts.timeout}
if jwtToken == "" {
username, password, err := resolveLoginCredentials(opts, cfg)
if err != nil {
logger.Error("resolve login", slog.Any("error", err))
os.Exit(1)
}
loginCtx := ctx
if opts.timeout > 0 {
var cancel context.CancelFunc
loginCtx, cancel = context.WithTimeout(ctx, opts.timeout)
defer cancel()
}
jwtToken, err = resolveJWTToken(loginCtx, client, opts.apiBaseURL, username, password)
if err != nil {
logger.Error("resolve jwt", slog.Any("error", err))
os.Exit(1)
}
}
query := strings.TrimSpace(strings.Join(flag.Args(), " "))
if query != "" {
if err := sendChat(ctx, client, opts.apiBaseURL, jwtToken, query); err != nil {
logger.Error("chat failed", slog.Any("error", err))
os.Exit(1)
}
return
}
if err := runInteractive(ctx, client, opts.apiBaseURL, jwtToken); err != nil {
logger.Error("chat failed", slog.Any("error", err))
os.Exit(1)
}
}
func parseFlags() cliOptions {
var opts cliOptions
defaultConfig := os.Getenv("CONFIG_PATH")
if strings.TrimSpace(defaultConfig) == "" {
defaultConfig = config.DefaultConfigPath
}
flag.StringVar(&opts.configPath, "config", defaultConfig, "Path to config.toml")
flag.StringVar(&opts.username, "username", "", "Username for login")
flag.StringVar(&opts.password, "password", "", "Password for login (or set MEMOH_PASSWORD)")
flag.StringVar(&opts.jwtToken, "jwt", "", "JWT token (optional)")
flag.StringVar(&opts.apiBaseURL, "api-url", "", "API server base URL (e.g. http://127.0.0.1:8080)")
flag.DurationVar(&opts.timeout, "timeout", 30*time.Second, "Request timeout")
flag.BoolVar(&opts.showVersion, "version", false, "Show version information")
flag.CommandLine.SetOutput(io.Discard)
containerID := flag.String("container-id", "", "")
flag.Parse()
return opts
}
func normalizeBaseURL(value string) string {
return strings.TrimRight(strings.TrimSpace(value), "/")
}
func defaultAPIBaseURL(addr string) string {
trimmed := strings.TrimSpace(addr)
if trimmed == "" {
return ""
}
if strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") {
return normalizeBaseURL(trimmed)
}
if strings.HasPrefix(trimmed, ":") {
return "http://127.0.0.1" + trimmed
}
return "http://" + trimmed
}
func resolveLoginCredentials(opts cliOptions, cfg config.Config) (string, string, error) {
username := strings.TrimSpace(opts.username)
if username == "" {
username = strings.TrimSpace(cfg.Admin.Username)
}
if username == "" {
return "", "", fmt.Errorf("username is required for login")
if *containerID == "" {
os.Exit(2)
}
password := strings.TrimSpace(opts.password)
if password == "" {
password = strings.TrimSpace(os.Getenv("MEMOH_PASSWORD"))
}
if password == "" {
if candidate := strings.TrimSpace(cfg.Admin.Password); candidate != "" && candidate != "change-your-password-here" {
password = candidate
cmd := buildMCPCommand(*containerID)
if err := runWithStdio(cmd); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
if password == "" {
return "", "", fmt.Errorf("password is required; pass --password or set MEMOH_PASSWORD")
}
return username, password, nil
}
func resolveJWTToken(ctx context.Context, client *http.Client, baseURL, username, password string) (string, error) {
resp, err := loginForToken(ctx, client, baseURL, username, password)
if err != nil {
return "", err
func buildMCPCommand(containerID string) *exec.Cmd {
execID := "mcp-" + strconv.FormatInt(time.Now().UnixNano(), 10)
if runtime.GOOS == "darwin" {
return exec.Command(
"limactl",
"shell",
"--tty=false",
"default",
"--",
"sudo",
"-n",
"ctr",
"-n",
"default",
"tasks",
"exec",
"--exec-id",
execID,
containerID,
"/mcp",
)
}
if strings.TrimSpace(resp.AccessToken) == "" {
return "", fmt.Errorf("login succeeded but token missing")
}
return resp.AccessToken, nil
return exec.Command(
"ctr",
"-n",
"default",
"tasks",
"exec",
"--exec-id",
execID,
containerID,
"/mcp",
)
}
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type loginResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresAt string `json:"expires_at"`
UserID string `json:"user_id"`
Username string `json:"username"`
}
func loginForToken(ctx context.Context, client *http.Client, baseURL, username, password string) (loginResponse, error) {
body, err := json.Marshal(loginRequest{
Username: username,
Password: password,
})
if err != nil {
return loginResponse{}, err
}
url := normalizeBaseURL(baseURL) + "/auth/login"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return loginResponse{}, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return loginResponse{}, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
payload, _ := io.ReadAll(resp.Body)
return loginResponse{}, fmt.Errorf("login failed: %s", strings.TrimSpace(string(payload)))
}
var parsed loginResponse
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
return loginResponse{}, err
}
return parsed, nil
}
func runInteractive(ctx context.Context, client *http.Client, baseURL, jwtToken string) error {
reader := bufio.NewScanner(os.Stdin)
reader.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
fmt.Fprint(os.Stdout, "You: ")
for reader.Scan() {
line := strings.TrimSpace(reader.Text())
if line == "" {
fmt.Fprint(os.Stdout, "You: ")
continue
}
lower := strings.ToLower(line)
if lower == "exit" || lower == "quit" {
return nil
}
if err := sendChat(ctx, client, baseURL, jwtToken, line); err != nil {
return err
}
fmt.Fprint(os.Stdout, "You: ")
}
return reader.Err()
}
func sendChat(ctx context.Context, client *http.Client, baseURL, jwtToken, query string) error {
return streamAPIChat(ctx, client, baseURL, jwtToken, chat.ChatRequest{Query: query})
}
func streamAPIChat(ctx context.Context, client *http.Client, baseURL, jwtToken string, req chat.ChatRequest) error {
body, err := json.Marshal(req)
func runWithStdio(cmd *exec.Cmd) error {
stdin, err := cmd.StdinPipe()
if err != nil {
return err
}
url := baseURL + "/chat/stream"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
stdout, err := cmd.StdoutPipe()
if err != nil {
_ = stdin.Close()
return err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Authorization", "Bearer "+jwtToken)
resp, err := client.Do(httpReq)
stderr, err := cmd.StderrPipe()
if err != nil {
_ = stdin.Close()
_ = stdout.Close()
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
payload, _ := io.ReadAll(resp.Body)
return fmt.Errorf("api server error: %s", strings.TrimSpace(string(payload)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "event:") {
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
break
}
if text, ok := extractStreamText([]byte(data)); ok {
fmt.Fprint(os.Stdout, text)
}
}
if err := scanner.Err(); err != nil {
if err := cmd.Start(); err != nil {
_ = stdin.Close()
_ = stdout.Close()
_ = stderr.Close()
return err
}
fmt.Fprintln(os.Stdout)
return nil
}
func renderMessageContent(raw interface{}) (string, bool) {
switch v := raw.(type) {
case string:
return v, true
case []interface{}:
parts := make([]string, 0, len(v))
for _, item := range v {
if text, ok := renderMessageContent(item); ok && strings.TrimSpace(text) != "" {
parts = append(parts, text)
}
}
if len(parts) > 0 {
return strings.Join(parts, ""), true
}
case map[string]interface{}:
if text, ok := v["text"].(string); ok && strings.TrimSpace(text) != "" {
return text, true
}
if text, ok := v["content"].(string); ok && strings.TrimSpace(text) != "" {
return text, true
}
if kind, ok := v["type"].(string); ok && kind == "text" {
if text, ok := v["text"].(string); ok {
return text, true
}
}
}
return "", false
}
var wg sync.WaitGroup
wg.Add(3)
go func() {
defer wg.Done()
_, _ = io.Copy(stdin, os.Stdin)
_ = stdin.Close()
}()
go func() {
defer wg.Done()
_, _ = io.Copy(os.Stdout, stdout)
}()
go func() {
defer wg.Done()
_, _ = io.Copy(os.Stderr, stderr)
}()
func extractStreamText(raw []byte) (string, bool) {
var payload interface{}
if err := json.Unmarshal(raw, &payload); err != nil {
return "", false
}
return extractTextFromPayload(payload)
}
func extractTextFromPayload(payload interface{}) (string, bool) {
switch v := payload.(type) {
case string:
if strings.TrimSpace(v) == "" {
return "", false
}
return v, true
case []interface{}:
parts := make([]string, 0, len(v))
for _, item := range v {
if text, ok := extractTextFromPayload(item); ok && strings.TrimSpace(text) != "" {
parts = append(parts, text)
}
}
if len(parts) > 0 {
return strings.Join(parts, ""), true
}
case map[string]interface{}:
if text, ok := v["textDelta"].(string); ok && strings.TrimSpace(text) != "" {
return text, true
}
if text, ok := v["text"].(string); ok && strings.TrimSpace(text) != "" {
return text, true
}
if content, ok := v["content"]; ok {
if text, ok := renderMessageContent(content); ok {
return text, true
}
}
if delta, ok := v["delta"]; ok {
if text, ok := extractTextFromPayload(delta); ok {
return text, true
}
}
if data, ok := v["data"]; ok {
if text, ok := extractTextFromPayload(data); ok {
return text, true
}
}
if msg, ok := v["message"]; ok {
if text, ok := extractTextFromPayload(msg); ok {
return text, true
}
}
}
return "", false
err = cmd.Wait()
wg.Wait()
return err
}