mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
Merge branch 'fix-issue-#78-bug' into v0.1
This commit is contained in:
@@ -2,7 +2,7 @@ name: Docker
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
branches: [main, "v*.*"]
|
||||
tags: ["v*"]
|
||||
paths-ignore:
|
||||
- "docs/**"
|
||||
@@ -11,7 +11,7 @@ on:
|
||||
release:
|
||||
types: [published]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
branches: [main, "v*.*"]
|
||||
paths-ignore:
|
||||
- "docs/**"
|
||||
- "**.md"
|
||||
@@ -23,7 +23,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
PUSH: ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') }}
|
||||
PUSH: ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/v') || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') }}
|
||||
REGISTRY: ghcr.io
|
||||
|
||||
permissions:
|
||||
@@ -126,7 +126,7 @@ jobs:
|
||||
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release')
|
||||
if: github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/v') || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release')
|
||||
needs: build
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
+47
-4
@@ -3,7 +3,7 @@ import { chatModule } from './modules/chat'
|
||||
import { corsMiddleware } from './middlewares/cors'
|
||||
import { errorMiddleware } from './middlewares/error'
|
||||
import { loadConfig, getBaseUrl as getBaseUrlByConfig } from '@memoh/config'
|
||||
import { AuthFetcher } from '@memoh/agent'
|
||||
import { AgentAuthContext, AuthFetcher } from '@memoh/agent'
|
||||
|
||||
const config = loadConfig('../config.toml')
|
||||
|
||||
@@ -11,12 +11,55 @@ export const getBaseUrl = () => {
|
||||
return getBaseUrlByConfig(config)
|
||||
}
|
||||
|
||||
export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => {
|
||||
function parseJwtExp(token: string): number | null {
|
||||
try {
|
||||
const base64Url = token.split('.')[1]
|
||||
if (!base64Url) return null
|
||||
const base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/')
|
||||
const jsonPayload = Buffer.from(base64, 'base64').toString('utf8')
|
||||
const payload = JSON.parse(jsonPayload)
|
||||
return payload.exp ? payload.exp * 1000 : null
|
||||
} catch (e) {
|
||||
console.error('Failed to parse JWT expiration from token', e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export const createAuthFetcher = (auth: AgentAuthContext): AuthFetcher => {
|
||||
let refreshPromise: Promise<string> | null = null
|
||||
return async (url: string, options?: RequestInit) => {
|
||||
if (auth.bearer) {
|
||||
const exp = parseJwtExp(auth.bearer)
|
||||
if (exp !== null && exp - Date.now() < 120000) { // Refresh if expiring in < 2 mins
|
||||
if (!refreshPromise) {
|
||||
refreshPromise = (async () => {
|
||||
const refreshUrl = new URL('/auth/refresh', `${getBaseUrl().replace(/\/$/, '')}/`).toString()
|
||||
const res = await fetch(refreshUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Authorization': `Bearer ${auth.bearer}` }
|
||||
})
|
||||
if (res.ok) {
|
||||
const data = await res.json()
|
||||
return data.access_token
|
||||
}
|
||||
throw new Error('Failed to refresh token')
|
||||
})().finally(() => {
|
||||
refreshPromise = null
|
||||
})
|
||||
}
|
||||
try {
|
||||
auth.bearer = await refreshPromise
|
||||
} catch (e) {
|
||||
console.error('Token refresh failed', e)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const requestOptions = options ?? {}
|
||||
const headers = new Headers(requestOptions.headers || {})
|
||||
if (bearer && !headers.has('Authorization')) {
|
||||
headers.set('Authorization', `Bearer ${bearer}`)
|
||||
if (auth.bearer && !headers.has('Authorization')) {
|
||||
headers.set('Authorization', `Bearer ${auth.bearer}`)
|
||||
}
|
||||
|
||||
const baseURL = getBaseUrl()
|
||||
|
||||
+18
-15
@@ -25,7 +25,11 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
.use(bearerMiddleware)
|
||||
.post('/', async ({ body, bearer }) => {
|
||||
console.log('chat', body)
|
||||
const authFetcher = createAuthFetcher(bearer)
|
||||
const auth = {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
}
|
||||
const authFetcher = createAuthFetcher(auth)
|
||||
const { ask } = createAgent({
|
||||
model: body.model as ModelConfig,
|
||||
activeContextTime: body.activeContextTime,
|
||||
@@ -33,10 +37,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
currentChannel: body.currentChannel,
|
||||
allowedActions: body.allowedActions,
|
||||
identity: body.identity,
|
||||
auth: {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
},
|
||||
auth,
|
||||
skills: body.usableSkills,
|
||||
mcpConnections: body.mcpConnections,
|
||||
inbox: body.inbox,
|
||||
@@ -55,7 +56,11 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
.post('/stream', async function* ({ body, bearer }) {
|
||||
console.log('stream', body)
|
||||
try {
|
||||
const authFetcher = createAuthFetcher(bearer)
|
||||
const auth = {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
}
|
||||
const authFetcher = createAuthFetcher(auth)
|
||||
const { stream } = createAgent({
|
||||
model: body.model as ModelConfig,
|
||||
activeContextTime: body.activeContextTime,
|
||||
@@ -63,10 +68,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
currentChannel: body.currentChannel,
|
||||
allowedActions: body.allowedActions,
|
||||
identity: body.identity,
|
||||
auth: {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
},
|
||||
auth,
|
||||
skills: body.usableSkills,
|
||||
mcpConnections: body.mcpConnections,
|
||||
inbox: body.inbox,
|
||||
@@ -96,17 +98,18 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
||||
})
|
||||
.post('/trigger-schedule', async ({ body, bearer }) => {
|
||||
console.log('trigger-schedule', body)
|
||||
const authFetcher = createAuthFetcher(bearer)
|
||||
const auth = {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
}
|
||||
const authFetcher = createAuthFetcher(auth)
|
||||
const { triggerSchedule } = createAgent({
|
||||
model: body.model as ModelConfig,
|
||||
activeContextTime: body.activeContextTime,
|
||||
channels: body.channels,
|
||||
currentChannel: body.currentChannel,
|
||||
identity: body.identity,
|
||||
auth: {
|
||||
bearer: bearer!,
|
||||
baseUrl: getBaseUrl(),
|
||||
},
|
||||
auth,
|
||||
skills: body.usableSkills,
|
||||
mcpConnections: body.mcpConnections,
|
||||
inbox: body.inbox,
|
||||
|
||||
@@ -159,6 +159,50 @@ func ChatTokenFromContext(c echo.Context) (ChatToken, error) {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// RefreshTokenFromContext extracts the current token from context and issues a new one
|
||||
// with the same claims but a renewed expiration time.
|
||||
func RefreshTokenFromContext(c echo.Context, secret string, defaultExpiresIn time.Duration) (string, time.Time, error) {
|
||||
token, ok := c.Get("user").(*jwt.Token)
|
||||
if !ok || token == nil || !token.Valid {
|
||||
return "", time.Time{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", time.Time{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token claims")
|
||||
}
|
||||
|
||||
// Calculate original duration if possible
|
||||
expiresIn := defaultExpiresIn
|
||||
if expRaw, ok := claims["exp"].(float64); ok {
|
||||
if iatRaw, ok := claims["iat"].(float64); ok {
|
||||
duration := time.Duration(expRaw-iatRaw) * time.Second
|
||||
if duration > 0 {
|
||||
expiresIn = duration
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(expiresIn)
|
||||
|
||||
// Create new claims, copying over existing ones but updating time bounds
|
||||
newClaims := jwt.MapClaims{}
|
||||
for k, v := range claims {
|
||||
newClaims[k] = v
|
||||
}
|
||||
newClaims["iat"] = now.Unix()
|
||||
newClaims["exp"] = expiresAt.Unix()
|
||||
|
||||
newToken := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims)
|
||||
signed, err := newToken.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
return signed, expiresAt, nil
|
||||
}
|
||||
|
||||
func claimString(claims jwt.MapClaims, key string) string {
|
||||
raw, ok := claims[key]
|
||||
if !ok || raw == nil {
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRefreshTokenFromContext(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
secret := "test-secret"
|
||||
userID := "user-123"
|
||||
|
||||
// Create an initial token with a 5-minute lifespan
|
||||
initialDuration := 5 * time.Minute
|
||||
initialTokenStr, _, err := GenerateToken(userID, secret, initialDuration)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Parse the token to place it into the echo context
|
||||
token, err := jwt.Parse(initialTokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
c.Set("user", token)
|
||||
|
||||
// Simulate some time passing to ensure the new token has a different 'iat' and 'exp'
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Run the refresh function
|
||||
defaultDuration := 1 * time.Hour
|
||||
newTokenStr, newExpiresAt, err := RefreshTokenFromContext(c, secret, defaultDuration)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, newTokenStr)
|
||||
|
||||
// Parse the original token claims for comparison
|
||||
originalClaims, ok := token.Claims.(jwt.MapClaims)
|
||||
assert.True(t, ok)
|
||||
origIat := int64(originalClaims["iat"].(float64))
|
||||
origExp := int64(originalClaims["exp"].(float64))
|
||||
|
||||
// Parse the new token
|
||||
newToken, err := jwt.Parse(newTokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, newToken.Valid)
|
||||
|
||||
newClaims, ok := newToken.Claims.(jwt.MapClaims)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Ensure standard payload claims are retained
|
||||
assert.Equal(t, userID, newClaims[claimSubject])
|
||||
assert.Equal(t, userID, newClaims[claimUserID])
|
||||
|
||||
// Validate the new time bounds
|
||||
newIat := int64(newClaims["iat"].(float64))
|
||||
newExp := int64(newClaims["exp"].(float64))
|
||||
|
||||
// 1. Ensure time has advanced
|
||||
assert.Greater(t, newIat, origIat)
|
||||
|
||||
// 2. Ensure the refreshed token has a positive lifetime and does not exceed the configured default duration
|
||||
lifetimeSeconds := newExp - newIat
|
||||
assert.Greater(t, lifetimeSeconds, int64(0))
|
||||
assert.LessOrEqual(t, lifetimeSeconds, int64(defaultDuration.Seconds()))
|
||||
|
||||
// 3. Ensure the return value matches the claim
|
||||
assert.Equal(t, newExpiresAt.Unix(), newExp)
|
||||
}
|
||||
|
||||
func TestRefreshTokenFromContext_MissingUser(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
secret := "test-secret"
|
||||
defaultDuration := 1 * time.Hour
|
||||
|
||||
// Context without the "user" key
|
||||
_, _, err := RefreshTokenFromContext(c, secret, defaultDuration)
|
||||
assert.Error(t, err)
|
||||
|
||||
httpErr, ok := err.(*echo.HTTPError)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, http.StatusUnauthorized, httpErr.Code)
|
||||
assert.Equal(t, "invalid token", httpErr.Message)
|
||||
}
|
||||
@@ -46,6 +46,7 @@ func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecre
|
||||
|
||||
func (h *AuthHandler) Register(e *echo.Echo) {
|
||||
e.POST("/auth/login", h.Login)
|
||||
e.POST("/auth/refresh", h.Refresh)
|
||||
}
|
||||
|
||||
// Login godoc
|
||||
@@ -103,3 +104,35 @@ func (h *AuthHandler) Login(c echo.Context) error {
|
||||
DisplayName: account.DisplayName,
|
||||
})
|
||||
}
|
||||
|
||||
type RefreshResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
}
|
||||
|
||||
// Refresh godoc
|
||||
// @Summary Refresh Token
|
||||
// @Description Issue a new JWT using the existing claims with updated expiration
|
||||
// @Tags auth
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} RefreshResponse
|
||||
// @Failure 401 {object} ErrorResponse
|
||||
// @Failure 500 {object} ErrorResponse
|
||||
// @Router /auth/refresh [post]
|
||||
func (h *AuthHandler) Refresh(c echo.Context) error {
|
||||
if strings.TrimSpace(h.jwtSecret) == "" {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "jwt secret not configured")
|
||||
}
|
||||
|
||||
token, expiresAt, err := auth.RefreshTokenFromContext(c, h.jwtSecret, h.expiresIn)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, err.Error())
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, RefreshResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user