diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 53f765b9..d22b5e1f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -2,7 +2,7 @@ name: Docker on: push: - branches: [main] + branches: [main, "v*.*"] tags: ["v*"] paths-ignore: - "docs/**" @@ -11,7 +11,7 @@ on: release: types: [published] pull_request: - branches: [main] + branches: [main, "v*.*"] paths-ignore: - "docs/**" - "**.md" @@ -23,7 +23,7 @@ concurrency: cancel-in-progress: true env: - PUSH: ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') }} + PUSH: ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/v') || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') }} REGISTRY: ghcr.io permissions: @@ -126,7 +126,7 @@ jobs: merge: runs-on: ubuntu-latest - if: github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') + if: github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/v') || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') needs: build strategy: matrix: diff --git a/agent/src/index.ts b/agent/src/index.ts index 4ae5c75d..8a82c950 100644 --- a/agent/src/index.ts +++ b/agent/src/index.ts @@ -3,7 +3,7 @@ import { chatModule } from './modules/chat' import { corsMiddleware } from './middlewares/cors' import { errorMiddleware } from './middlewares/error' import { loadConfig, getBaseUrl as getBaseUrlByConfig } from '@memoh/config' -import { AuthFetcher } from '@memoh/agent' +import { AgentAuthContext, AuthFetcher } from '@memoh/agent' const config = loadConfig('../config.toml') @@ -11,12 +11,55 @@ export const getBaseUrl = () => { return getBaseUrlByConfig(config) } -export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => { +function parseJwtExp(token: string): number | null { + try { + const base64Url = token.split('.')[1] + if (!base64Url) return null + const base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/') + const jsonPayload = Buffer.from(base64, 'base64').toString('utf8') + const payload = JSON.parse(jsonPayload) + return payload.exp ? payload.exp * 1000 : null + } catch (e) { + console.error('Failed to parse JWT expiration from token', e) + return null + } +} + +export const createAuthFetcher = (auth: AgentAuthContext): AuthFetcher => { + let refreshPromise: Promise | null = null return async (url: string, options?: RequestInit) => { + if (auth.bearer) { + const exp = parseJwtExp(auth.bearer) + if (exp !== null && exp - Date.now() < 120000) { // Refresh if expiring in < 2 mins + if (!refreshPromise) { + refreshPromise = (async () => { + const refreshUrl = new URL('/auth/refresh', `${getBaseUrl().replace(/\/$/, '')}/`).toString() + const res = await fetch(refreshUrl, { + method: 'POST', + headers: { 'Authorization': `Bearer ${auth.bearer}` } + }) + if (res.ok) { + const data = await res.json() + return data.access_token + } + throw new Error('Failed to refresh token') + })().finally(() => { + refreshPromise = null + }) + } + try { + auth.bearer = await refreshPromise + } catch (e) { + console.error('Token refresh failed', e) + throw e + } + } + } + const requestOptions = options ?? {} const headers = new Headers(requestOptions.headers || {}) - if (bearer && !headers.has('Authorization')) { - headers.set('Authorization', `Bearer ${bearer}`) + if (auth.bearer && !headers.has('Authorization')) { + headers.set('Authorization', `Bearer ${auth.bearer}`) } const baseURL = getBaseUrl() diff --git a/agent/src/modules/chat.ts b/agent/src/modules/chat.ts index 4758ba91..a70a8c4f 100644 --- a/agent/src/modules/chat.ts +++ b/agent/src/modules/chat.ts @@ -25,7 +25,11 @@ export const chatModule = new Elysia({ prefix: '/chat' }) .use(bearerMiddleware) .post('/', async ({ body, bearer }) => { console.log('chat', body) - const authFetcher = createAuthFetcher(bearer) + const auth = { + bearer: bearer!, + baseUrl: getBaseUrl(), + } + const authFetcher = createAuthFetcher(auth) const { ask } = createAgent({ model: body.model as ModelConfig, activeContextTime: body.activeContextTime, @@ -33,10 +37,7 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentChannel: body.currentChannel, allowedActions: body.allowedActions, identity: body.identity, - auth: { - bearer: bearer!, - baseUrl: getBaseUrl(), - }, + auth, skills: body.usableSkills, mcpConnections: body.mcpConnections, inbox: body.inbox, @@ -55,7 +56,11 @@ export const chatModule = new Elysia({ prefix: '/chat' }) .post('/stream', async function* ({ body, bearer }) { console.log('stream', body) try { - const authFetcher = createAuthFetcher(bearer) + const auth = { + bearer: bearer!, + baseUrl: getBaseUrl(), + } + const authFetcher = createAuthFetcher(auth) const { stream } = createAgent({ model: body.model as ModelConfig, activeContextTime: body.activeContextTime, @@ -63,10 +68,7 @@ export const chatModule = new Elysia({ prefix: '/chat' }) currentChannel: body.currentChannel, allowedActions: body.allowedActions, identity: body.identity, - auth: { - bearer: bearer!, - baseUrl: getBaseUrl(), - }, + auth, skills: body.usableSkills, mcpConnections: body.mcpConnections, inbox: body.inbox, @@ -96,17 +98,18 @@ export const chatModule = new Elysia({ prefix: '/chat' }) }) .post('/trigger-schedule', async ({ body, bearer }) => { console.log('trigger-schedule', body) - const authFetcher = createAuthFetcher(bearer) + const auth = { + bearer: bearer!, + baseUrl: getBaseUrl(), + } + const authFetcher = createAuthFetcher(auth) const { triggerSchedule } = createAgent({ model: body.model as ModelConfig, activeContextTime: body.activeContextTime, channels: body.channels, currentChannel: body.currentChannel, identity: body.identity, - auth: { - bearer: bearer!, - baseUrl: getBaseUrl(), - }, + auth, skills: body.usableSkills, mcpConnections: body.mcpConnections, inbox: body.inbox, diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 86854928..340cab23 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -159,6 +159,50 @@ func ChatTokenFromContext(c echo.Context) (ChatToken, error) { return info, nil } +// RefreshTokenFromContext extracts the current token from context and issues a new one +// with the same claims but a renewed expiration time. +func RefreshTokenFromContext(c echo.Context, secret string, defaultExpiresIn time.Duration) (string, time.Time, error) { + token, ok := c.Get("user").(*jwt.Token) + if !ok || token == nil || !token.Valid { + return "", time.Time{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return "", time.Time{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token claims") + } + + // Calculate original duration if possible + expiresIn := defaultExpiresIn + if expRaw, ok := claims["exp"].(float64); ok { + if iatRaw, ok := claims["iat"].(float64); ok { + duration := time.Duration(expRaw-iatRaw) * time.Second + if duration > 0 { + expiresIn = duration + } + } + } + + now := time.Now().UTC() + expiresAt := now.Add(expiresIn) + + // Create new claims, copying over existing ones but updating time bounds + newClaims := jwt.MapClaims{} + for k, v := range claims { + newClaims[k] = v + } + newClaims["iat"] = now.Unix() + newClaims["exp"] = expiresAt.Unix() + + newToken := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims) + signed, err := newToken.SignedString([]byte(secret)) + if err != nil { + return "", time.Time{}, err + } + + return signed, expiresAt, nil +} + func claimString(claims jwt.MapClaims, key string) string { raw, ok := claims[key] if !ok || raw == nil { diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go new file mode 100644 index 00000000..5ccbf8ba --- /dev/null +++ b/internal/auth/jwt_test.go @@ -0,0 +1,97 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestRefreshTokenFromContext(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + secret := "test-secret" + userID := "user-123" + + // Create an initial token with a 5-minute lifespan + initialDuration := 5 * time.Minute + initialTokenStr, _, err := GenerateToken(userID, secret, initialDuration) + assert.NoError(t, err) + + // Parse the token to place it into the echo context + token, err := jwt.Parse(initialTokenStr, func(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + assert.NoError(t, err) + c.Set("user", token) + + // Simulate some time passing to ensure the new token has a different 'iat' and 'exp' + time.Sleep(1 * time.Second) + + // Run the refresh function + defaultDuration := 1 * time.Hour + newTokenStr, newExpiresAt, err := RefreshTokenFromContext(c, secret, defaultDuration) + assert.NoError(t, err) + assert.NotEmpty(t, newTokenStr) + + // Parse the original token claims for comparison + originalClaims, ok := token.Claims.(jwt.MapClaims) + assert.True(t, ok) + origIat := int64(originalClaims["iat"].(float64)) + origExp := int64(originalClaims["exp"].(float64)) + + // Parse the new token + newToken, err := jwt.Parse(newTokenStr, func(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + assert.NoError(t, err) + assert.True(t, newToken.Valid) + + newClaims, ok := newToken.Claims.(jwt.MapClaims) + assert.True(t, ok) + + // Ensure standard payload claims are retained + assert.Equal(t, userID, newClaims[claimSubject]) + assert.Equal(t, userID, newClaims[claimUserID]) + + // Validate the new time bounds + newIat := int64(newClaims["iat"].(float64)) + newExp := int64(newClaims["exp"].(float64)) + + // 1. Ensure time has advanced + assert.Greater(t, newIat, origIat) + + // 2. Ensure the refreshed token has a positive lifetime and does not exceed the configured default duration + lifetimeSeconds := newExp - newIat + assert.Greater(t, lifetimeSeconds, int64(0)) + assert.LessOrEqual(t, lifetimeSeconds, int64(defaultDuration.Seconds())) + + // 3. Ensure the return value matches the claim + assert.Equal(t, newExpiresAt.Unix(), newExp) +} + +func TestRefreshTokenFromContext_MissingUser(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + secret := "test-secret" + defaultDuration := 1 * time.Hour + + // Context without the "user" key + _, _, err := RefreshTokenFromContext(c, secret, defaultDuration) + assert.Error(t, err) + + httpErr, ok := err.(*echo.HTTPError) + assert.True(t, ok) + assert.Equal(t, http.StatusUnauthorized, httpErr.Code) + assert.Equal(t, "invalid token", httpErr.Message) +} diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index dfb4503d..d3deb6f8 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -46,6 +46,7 @@ func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecre func (h *AuthHandler) Register(e *echo.Echo) { e.POST("/auth/login", h.Login) + e.POST("/auth/refresh", h.Refresh) } // Login godoc @@ -103,3 +104,35 @@ func (h *AuthHandler) Login(c echo.Context) error { DisplayName: account.DisplayName, }) } + +type RefreshResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresAt string `json:"expires_at"` +} + +// Refresh godoc +// @Summary Refresh Token +// @Description Issue a new JWT using the existing claims with updated expiration +// @Tags auth +// @Security BearerAuth +// @Success 200 {object} RefreshResponse +// @Failure 401 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /auth/refresh [post] +func (h *AuthHandler) Refresh(c echo.Context) error { + if strings.TrimSpace(h.jwtSecret) == "" { + return echo.NewHTTPError(http.StatusInternalServerError, "jwt secret not configured") + } + + token, expiresAt, err := auth.RefreshTokenFromContext(c, h.jwtSecret, h.expiresIn) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, err.Error()) + } + + return c.JSON(http.StatusOK, RefreshResponse{ + AccessToken: token, + TokenType: "Bearer", + ExpiresAt: expiresAt.Format(time.RFC3339), + }) +}