mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
Merge branch 'fix-issue-#78-bug' into v0.1
This commit is contained in:
@@ -2,7 +2,7 @@ name: Docker
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [main]
|
branches: [main, "v*.*"]
|
||||||
tags: ["v*"]
|
tags: ["v*"]
|
||||||
paths-ignore:
|
paths-ignore:
|
||||||
- "docs/**"
|
- "docs/**"
|
||||||
@@ -11,7 +11,7 @@ on:
|
|||||||
release:
|
release:
|
||||||
types: [published]
|
types: [published]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main, "v*.*"]
|
||||||
paths-ignore:
|
paths-ignore:
|
||||||
- "docs/**"
|
- "docs/**"
|
||||||
- "**.md"
|
- "**.md"
|
||||||
@@ -23,7 +23,7 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
PUSH: ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') }}
|
PUSH: ${{ github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/v') || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release') }}
|
||||||
REGISTRY: ghcr.io
|
REGISTRY: ghcr.io
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
|
|
||||||
merge:
|
merge:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release')
|
if: github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/v') || startsWith(github.ref, 'refs/tags/') || github.event_name == 'release')
|
||||||
needs: build
|
needs: build
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|||||||
+47
-4
@@ -3,7 +3,7 @@ import { chatModule } from './modules/chat'
|
|||||||
import { corsMiddleware } from './middlewares/cors'
|
import { corsMiddleware } from './middlewares/cors'
|
||||||
import { errorMiddleware } from './middlewares/error'
|
import { errorMiddleware } from './middlewares/error'
|
||||||
import { loadConfig, getBaseUrl as getBaseUrlByConfig } from '@memoh/config'
|
import { loadConfig, getBaseUrl as getBaseUrlByConfig } from '@memoh/config'
|
||||||
import { AuthFetcher } from '@memoh/agent'
|
import { AgentAuthContext, AuthFetcher } from '@memoh/agent'
|
||||||
|
|
||||||
const config = loadConfig('../config.toml')
|
const config = loadConfig('../config.toml')
|
||||||
|
|
||||||
@@ -11,12 +11,55 @@ export const getBaseUrl = () => {
|
|||||||
return getBaseUrlByConfig(config)
|
return getBaseUrlByConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
export const createAuthFetcher = (bearer: string | undefined): AuthFetcher => {
|
function parseJwtExp(token: string): number | null {
|
||||||
|
try {
|
||||||
|
const base64Url = token.split('.')[1]
|
||||||
|
if (!base64Url) return null
|
||||||
|
const base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/')
|
||||||
|
const jsonPayload = Buffer.from(base64, 'base64').toString('utf8')
|
||||||
|
const payload = JSON.parse(jsonPayload)
|
||||||
|
return payload.exp ? payload.exp * 1000 : null
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to parse JWT expiration from token', e)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const createAuthFetcher = (auth: AgentAuthContext): AuthFetcher => {
|
||||||
|
let refreshPromise: Promise<string> | null = null
|
||||||
return async (url: string, options?: RequestInit) => {
|
return async (url: string, options?: RequestInit) => {
|
||||||
|
if (auth.bearer) {
|
||||||
|
const exp = parseJwtExp(auth.bearer)
|
||||||
|
if (exp !== null && exp - Date.now() < 120000) { // Refresh if expiring in < 2 mins
|
||||||
|
if (!refreshPromise) {
|
||||||
|
refreshPromise = (async () => {
|
||||||
|
const refreshUrl = new URL('/auth/refresh', `${getBaseUrl().replace(/\/$/, '')}/`).toString()
|
||||||
|
const res = await fetch(refreshUrl, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Authorization': `Bearer ${auth.bearer}` }
|
||||||
|
})
|
||||||
|
if (res.ok) {
|
||||||
|
const data = await res.json()
|
||||||
|
return data.access_token
|
||||||
|
}
|
||||||
|
throw new Error('Failed to refresh token')
|
||||||
|
})().finally(() => {
|
||||||
|
refreshPromise = null
|
||||||
|
})
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
auth.bearer = await refreshPromise
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Token refresh failed', e)
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const requestOptions = options ?? {}
|
const requestOptions = options ?? {}
|
||||||
const headers = new Headers(requestOptions.headers || {})
|
const headers = new Headers(requestOptions.headers || {})
|
||||||
if (bearer && !headers.has('Authorization')) {
|
if (auth.bearer && !headers.has('Authorization')) {
|
||||||
headers.set('Authorization', `Bearer ${bearer}`)
|
headers.set('Authorization', `Bearer ${auth.bearer}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
const baseURL = getBaseUrl()
|
const baseURL = getBaseUrl()
|
||||||
|
|||||||
+18
-15
@@ -25,7 +25,11 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
|||||||
.use(bearerMiddleware)
|
.use(bearerMiddleware)
|
||||||
.post('/', async ({ body, bearer }) => {
|
.post('/', async ({ body, bearer }) => {
|
||||||
console.log('chat', body)
|
console.log('chat', body)
|
||||||
const authFetcher = createAuthFetcher(bearer)
|
const auth = {
|
||||||
|
bearer: bearer!,
|
||||||
|
baseUrl: getBaseUrl(),
|
||||||
|
}
|
||||||
|
const authFetcher = createAuthFetcher(auth)
|
||||||
const { ask } = createAgent({
|
const { ask } = createAgent({
|
||||||
model: body.model as ModelConfig,
|
model: body.model as ModelConfig,
|
||||||
activeContextTime: body.activeContextTime,
|
activeContextTime: body.activeContextTime,
|
||||||
@@ -33,10 +37,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
|||||||
currentChannel: body.currentChannel,
|
currentChannel: body.currentChannel,
|
||||||
allowedActions: body.allowedActions,
|
allowedActions: body.allowedActions,
|
||||||
identity: body.identity,
|
identity: body.identity,
|
||||||
auth: {
|
auth,
|
||||||
bearer: bearer!,
|
|
||||||
baseUrl: getBaseUrl(),
|
|
||||||
},
|
|
||||||
skills: body.usableSkills,
|
skills: body.usableSkills,
|
||||||
mcpConnections: body.mcpConnections,
|
mcpConnections: body.mcpConnections,
|
||||||
inbox: body.inbox,
|
inbox: body.inbox,
|
||||||
@@ -55,7 +56,11 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
|||||||
.post('/stream', async function* ({ body, bearer }) {
|
.post('/stream', async function* ({ body, bearer }) {
|
||||||
console.log('stream', body)
|
console.log('stream', body)
|
||||||
try {
|
try {
|
||||||
const authFetcher = createAuthFetcher(bearer)
|
const auth = {
|
||||||
|
bearer: bearer!,
|
||||||
|
baseUrl: getBaseUrl(),
|
||||||
|
}
|
||||||
|
const authFetcher = createAuthFetcher(auth)
|
||||||
const { stream } = createAgent({
|
const { stream } = createAgent({
|
||||||
model: body.model as ModelConfig,
|
model: body.model as ModelConfig,
|
||||||
activeContextTime: body.activeContextTime,
|
activeContextTime: body.activeContextTime,
|
||||||
@@ -63,10 +68,7 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
|||||||
currentChannel: body.currentChannel,
|
currentChannel: body.currentChannel,
|
||||||
allowedActions: body.allowedActions,
|
allowedActions: body.allowedActions,
|
||||||
identity: body.identity,
|
identity: body.identity,
|
||||||
auth: {
|
auth,
|
||||||
bearer: bearer!,
|
|
||||||
baseUrl: getBaseUrl(),
|
|
||||||
},
|
|
||||||
skills: body.usableSkills,
|
skills: body.usableSkills,
|
||||||
mcpConnections: body.mcpConnections,
|
mcpConnections: body.mcpConnections,
|
||||||
inbox: body.inbox,
|
inbox: body.inbox,
|
||||||
@@ -96,17 +98,18 @@ export const chatModule = new Elysia({ prefix: '/chat' })
|
|||||||
})
|
})
|
||||||
.post('/trigger-schedule', async ({ body, bearer }) => {
|
.post('/trigger-schedule', async ({ body, bearer }) => {
|
||||||
console.log('trigger-schedule', body)
|
console.log('trigger-schedule', body)
|
||||||
const authFetcher = createAuthFetcher(bearer)
|
const auth = {
|
||||||
|
bearer: bearer!,
|
||||||
|
baseUrl: getBaseUrl(),
|
||||||
|
}
|
||||||
|
const authFetcher = createAuthFetcher(auth)
|
||||||
const { triggerSchedule } = createAgent({
|
const { triggerSchedule } = createAgent({
|
||||||
model: body.model as ModelConfig,
|
model: body.model as ModelConfig,
|
||||||
activeContextTime: body.activeContextTime,
|
activeContextTime: body.activeContextTime,
|
||||||
channels: body.channels,
|
channels: body.channels,
|
||||||
currentChannel: body.currentChannel,
|
currentChannel: body.currentChannel,
|
||||||
identity: body.identity,
|
identity: body.identity,
|
||||||
auth: {
|
auth,
|
||||||
bearer: bearer!,
|
|
||||||
baseUrl: getBaseUrl(),
|
|
||||||
},
|
|
||||||
skills: body.usableSkills,
|
skills: body.usableSkills,
|
||||||
mcpConnections: body.mcpConnections,
|
mcpConnections: body.mcpConnections,
|
||||||
inbox: body.inbox,
|
inbox: body.inbox,
|
||||||
|
|||||||
@@ -159,6 +159,50 @@ func ChatTokenFromContext(c echo.Context) (ChatToken, error) {
|
|||||||
return info, nil
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshTokenFromContext extracts the current token from context and issues a new one
|
||||||
|
// with the same claims but a renewed expiration time.
|
||||||
|
func RefreshTokenFromContext(c echo.Context, secret string, defaultExpiresIn time.Duration) (string, time.Time, error) {
|
||||||
|
token, ok := c.Get("user").(*jwt.Token)
|
||||||
|
if !ok || token == nil || !token.Valid {
|
||||||
|
return "", time.Time{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return "", time.Time{}, echo.NewHTTPError(http.StatusUnauthorized, "invalid token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate original duration if possible
|
||||||
|
expiresIn := defaultExpiresIn
|
||||||
|
if expRaw, ok := claims["exp"].(float64); ok {
|
||||||
|
if iatRaw, ok := claims["iat"].(float64); ok {
|
||||||
|
duration := time.Duration(expRaw-iatRaw) * time.Second
|
||||||
|
if duration > 0 {
|
||||||
|
expiresIn = duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
expiresAt := now.Add(expiresIn)
|
||||||
|
|
||||||
|
// Create new claims, copying over existing ones but updating time bounds
|
||||||
|
newClaims := jwt.MapClaims{}
|
||||||
|
for k, v := range claims {
|
||||||
|
newClaims[k] = v
|
||||||
|
}
|
||||||
|
newClaims["iat"] = now.Unix()
|
||||||
|
newClaims["exp"] = expiresAt.Unix()
|
||||||
|
|
||||||
|
newToken := jwt.NewWithClaims(jwt.SigningMethodHS256, newClaims)
|
||||||
|
signed, err := newToken.SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
return "", time.Time{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return signed, expiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
func claimString(claims jwt.MapClaims, key string) string {
|
func claimString(claims jwt.MapClaims, key string) string {
|
||||||
raw, ok := claims[key]
|
raw, ok := claims[key]
|
||||||
if !ok || raw == nil {
|
if !ok || raw == nil {
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRefreshTokenFromContext(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
secret := "test-secret"
|
||||||
|
userID := "user-123"
|
||||||
|
|
||||||
|
// Create an initial token with a 5-minute lifespan
|
||||||
|
initialDuration := 5 * time.Minute
|
||||||
|
initialTokenStr, _, err := GenerateToken(userID, secret, initialDuration)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Parse the token to place it into the echo context
|
||||||
|
token, err := jwt.Parse(initialTokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(secret), nil
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
c.Set("user", token)
|
||||||
|
|
||||||
|
// Simulate some time passing to ensure the new token has a different 'iat' and 'exp'
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
|
// Run the refresh function
|
||||||
|
defaultDuration := 1 * time.Hour
|
||||||
|
newTokenStr, newExpiresAt, err := RefreshTokenFromContext(c, secret, defaultDuration)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, newTokenStr)
|
||||||
|
|
||||||
|
// Parse the original token claims for comparison
|
||||||
|
originalClaims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
assert.True(t, ok)
|
||||||
|
origIat := int64(originalClaims["iat"].(float64))
|
||||||
|
origExp := int64(originalClaims["exp"].(float64))
|
||||||
|
|
||||||
|
// Parse the new token
|
||||||
|
newToken, err := jwt.Parse(newTokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(secret), nil
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, newToken.Valid)
|
||||||
|
|
||||||
|
newClaims, ok := newToken.Claims.(jwt.MapClaims)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// Ensure standard payload claims are retained
|
||||||
|
assert.Equal(t, userID, newClaims[claimSubject])
|
||||||
|
assert.Equal(t, userID, newClaims[claimUserID])
|
||||||
|
|
||||||
|
// Validate the new time bounds
|
||||||
|
newIat := int64(newClaims["iat"].(float64))
|
||||||
|
newExp := int64(newClaims["exp"].(float64))
|
||||||
|
|
||||||
|
// 1. Ensure time has advanced
|
||||||
|
assert.Greater(t, newIat, origIat)
|
||||||
|
|
||||||
|
// 2. Ensure the refreshed token has a positive lifetime and does not exceed the configured default duration
|
||||||
|
lifetimeSeconds := newExp - newIat
|
||||||
|
assert.Greater(t, lifetimeSeconds, int64(0))
|
||||||
|
assert.LessOrEqual(t, lifetimeSeconds, int64(defaultDuration.Seconds()))
|
||||||
|
|
||||||
|
// 3. Ensure the return value matches the claim
|
||||||
|
assert.Equal(t, newExpiresAt.Unix(), newExp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenFromContext_MissingUser(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
secret := "test-secret"
|
||||||
|
defaultDuration := 1 * time.Hour
|
||||||
|
|
||||||
|
// Context without the "user" key
|
||||||
|
_, _, err := RefreshTokenFromContext(c, secret, defaultDuration)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
httpErr, ok := err.(*echo.HTTPError)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, httpErr.Code)
|
||||||
|
assert.Equal(t, "invalid token", httpErr.Message)
|
||||||
|
}
|
||||||
@@ -46,6 +46,7 @@ func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecre
|
|||||||
|
|
||||||
func (h *AuthHandler) Register(e *echo.Echo) {
|
func (h *AuthHandler) Register(e *echo.Echo) {
|
||||||
e.POST("/auth/login", h.Login)
|
e.POST("/auth/login", h.Login)
|
||||||
|
e.POST("/auth/refresh", h.Refresh)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login godoc
|
// Login godoc
|
||||||
@@ -103,3 +104,35 @@ func (h *AuthHandler) Login(c echo.Context) error {
|
|||||||
DisplayName: account.DisplayName,
|
DisplayName: account.DisplayName,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RefreshResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresAt string `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh godoc
|
||||||
|
// @Summary Refresh Token
|
||||||
|
// @Description Issue a new JWT using the existing claims with updated expiration
|
||||||
|
// @Tags auth
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Success 200 {object} RefreshResponse
|
||||||
|
// @Failure 401 {object} ErrorResponse
|
||||||
|
// @Failure 500 {object} ErrorResponse
|
||||||
|
// @Router /auth/refresh [post]
|
||||||
|
func (h *AuthHandler) Refresh(c echo.Context) error {
|
||||||
|
if strings.TrimSpace(h.jwtSecret) == "" {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "jwt secret not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, expiresAt, err := auth.RefreshTokenFromContext(c, h.jwtSecret, h.expiresIn)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusUnauthorized, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, RefreshResponse{
|
||||||
|
AccessToken: token,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user