fix(mail): callback URL for Gmail OAuth (#303)

This commit is contained in:
晨苒
2026-03-29 16:29:34 +08:00
committed by GitHub
parent f554eee20b
commit 0b56fb0bf7
4 changed files with 93 additions and 8 deletions
+1 -1
View File
@@ -786,7 +786,7 @@ func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, token
if strings.HasPrefix(host, ":") {
host = "localhost" + host
}
callbackURL := "http://" + host + "/email/oauth/callback"
callbackURL := "http://" + host + "/api/email/oauth/callback"
return handlers.NewEmailOAuthHandler(log, service, tokenStore, callbackURL)
}
+1 -1
View File
@@ -901,7 +901,7 @@ func provideEmailOAuthHandler(log *slog.Logger, service *emailpkg.Service, token
if strings.HasPrefix(host, ":") {
host = "localhost" + host
}
callbackURL := "http://" + host + "/email/oauth/callback"
callbackURL := "http://" + host + "/api/email/oauth/callback"
return handlers.NewEmailOAuthHandler(log, service, tokenStore, callbackURL)
}
+90 -5
View File
@@ -2,10 +2,12 @@ package handlers
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"errors"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
@@ -16,6 +18,8 @@ import (
emailgmail "github.com/memohai/memoh/internal/email/adapters/gmail"
)
const emailOAuthCallbackPath = "/api/email/oauth/callback"
// EmailOAuthHandler handles the OAuth2 authorization flow for Gmail providers.
type EmailOAuthHandler struct {
service *email.Service
@@ -47,6 +51,7 @@ func (h *EmailOAuthHandler) Register(e *echo.Echo) {
e.GET("/email-providers/:id/oauth/status", h.Status)
e.DELETE("/email-providers/:id/oauth/token", h.Revoke)
e.GET("/email/oauth/callback", h.Callback)
e.GET(emailOAuthCallbackPath, h.Callback)
}
// Authorize godoc
@@ -69,7 +74,8 @@ func (h *EmailOAuthHandler) Authorize(c echo.Context) error {
return echo.NewHTTPError(http.StatusNotFound, "provider not found")
}
state, err := generateState()
callbackURL := h.effectiveCallbackURL(c)
state, err := generateState(callbackURL)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to generate state")
}
@@ -85,7 +91,7 @@ func (h *EmailOAuthHandler) Authorize(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "client_id is not configured for this provider")
}
adapter := emailgmail.New(h.logger, h.tokenStore)
authURL = adapter.AuthorizeURL(clientID, h.callbackURL, state)
authURL = adapter.AuthorizeURL(clientID, callbackURL, state)
}
if authURL == "" {
return echo.NewHTTPError(http.StatusBadRequest, "provider does not support OAuth2")
@@ -132,7 +138,11 @@ func (h *EmailOAuthHandler) Callback(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "provider does not support OAuth2")
}
adapter := emailgmail.New(h.logger, h.tokenStore)
if err := adapter.ExchangeCode(ctx, provider.Config, stored.ProviderID, code, h.callbackURL); err != nil {
redirectURI := callbackURLFromState(state)
if redirectURI == "" {
redirectURI = h.effectiveCallbackURL(c)
}
if err := adapter.ExchangeCode(ctx, provider.Config, stored.ProviderID, code, redirectURI); err != nil {
h.logger.Error("gmail code exchange failed", slog.Any("error", err))
return echo.NewHTTPError(http.StatusInternalServerError, "token exchange failed")
}
@@ -236,10 +246,85 @@ func isProviderConfigured(provider email.ProviderResponse) bool {
return strings.TrimSpace(clientID) != ""
}
func generateState() (string, error) {
func generateState(callbackURL string) (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
state := hex.EncodeToString(b)
if callbackURL == "" {
return state, nil
}
return state + "." + base64.RawURLEncoding.EncodeToString([]byte(callbackURL)), nil
}
func (h *EmailOAuthHandler) effectiveCallbackURL(c echo.Context) string {
if baseURL := requestBaseURL(c.Request()); baseURL != "" {
return strings.TrimRight(baseURL, "/") + emailOAuthCallbackPath
}
return h.callbackURL
}
func requestBaseURL(req *http.Request) string {
if origin := normalizeOrigin(req.Header.Get(echo.HeaderOrigin)); origin != "" {
return origin
}
if referer := normalizeOrigin(req.Referer()); referer != "" {
return referer
}
host := firstHeaderValue(req.Header.Get("X-Forwarded-Host"))
if host == "" {
host = strings.TrimSpace(req.Host)
}
if host == "" {
return ""
}
proto := firstHeaderValue(req.Header.Get(echo.HeaderXForwardedProto))
if proto == "" {
if req.TLS != nil {
proto = "https"
} else {
proto = "http"
}
}
if port := firstHeaderValue(req.Header.Get("X-Forwarded-Port")); port != "" &&
!strings.Contains(host, ":") &&
(proto != "https" || port != "443") &&
(proto != "http" || port != "80") {
host += ":" + port
}
return proto + "://" + host
}
func normalizeOrigin(raw string) string {
origin := firstHeaderValue(raw)
if origin == "" {
return ""
}
parsed, err := url.Parse(origin)
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
return ""
}
return parsed.Scheme + "://" + parsed.Host
}
func firstHeaderValue(raw string) string {
if raw == "" {
return ""
}
parts := strings.Split(raw, ",")
return strings.TrimSpace(parts[0])
}
func callbackURLFromState(state string) string {
_, encoded, ok := strings.Cut(state, ".")
if !ok || encoded == "" {
return ""
}
callbackURL, err := base64.RawURLEncoding.DecodeString(encoded)
if err != nil {
return ""
}
return strings.TrimSpace(string(callbackURL))
}
+1 -1
View File
@@ -87,7 +87,7 @@ func shouldSkipJWT(path string) bool {
if strings.HasPrefix(path, "/email/mailgun/webhook/") {
return true
}
if strings.HasPrefix(path, "/email/oauth/callback") {
if strings.HasPrefix(path, "/email/oauth/callback") || strings.HasPrefix(path, "/api/email/oauth/callback") {
return true
}
if strings.HasPrefix(path, "/providers/oauth/callback") {