feat: improve api

This commit is contained in:
Acbox
2026-01-10 22:18:50 +08:00
parent 661d742750
commit fee657ddd2
22 changed files with 655 additions and 228 deletions
+1
View File
@@ -16,6 +16,7 @@
"@elysiajs/cron": "^1.4.1",
"@elysiajs/eden": "^1.4.6",
"@elysiajs/jwt": "^1.2.0",
"@elysiajs/openapi": "^1.4.13",
"@memohome/agent": "workspace:*",
"@memohome/db": "workspace:*",
"@memohome/memory": "workspace:*",
+4 -1
View File
@@ -1,11 +1,14 @@
import { Elysia } from 'elysia'
import { corsMiddleware } from './middlewares'
import { corsMiddleware, errorMiddleware } from './middlewares'
import { agentModule, authModule, modelModule, settingsModule, userModule } from './modules'
import { memoryModule } from './modules/memory'
import openapi from '@elysiajs/openapi'
const port = process.env.API_SERVER_PORT || 7002
export const app = new Elysia()
.use(errorMiddleware)
.use(openapi())
.use(corsMiddleware)
.use(authModule)
.use(agentModule)
+37 -29
View File
@@ -2,20 +2,40 @@ import { Elysia } from 'elysia'
import { bearer } from '@elysiajs/bearer'
import { jwt } from '@elysiajs/jwt'
/**
* JWT 配置常量
*/
const JWT_CONFIG = {
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
}
/**
* 用户信息类型
*/
export type AuthUser = {
userId: string
username: string
role: string
}
/**
* 共享的基础认证插件
* 提供 JWT 和 Bearer token 功能
*/
export const jwtPlugin = new Elysia({ name: 'jwt-plugin' })
.use(jwt(JWT_CONFIG))
.use(bearer())
/**
* 认证中间件
* 验证 Bearer token 并将用户信息注入到 context 中
*/
export const authMiddleware = new Elysia({ name: 'auth' })
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(jwt(JWT_CONFIG))
.use(bearer())
.derive(async ({ bearer, jwt, set }) => {
.derive({ as: 'scoped' }, async ({ bearer, jwt, set }) => {
if (!bearer) {
set.status = 401
throw new Error('No bearer token provided')
@@ -33,7 +53,7 @@ export const authMiddleware = new Elysia({ name: 'auth' })
userId: payload.userId as string,
username: payload.username as string,
role: payload.role as string,
},
} as AuthUser,
}
})
@@ -42,23 +62,17 @@ export const authMiddleware = new Elysia({ name: 'auth' })
* 如果有 token 则验证,没有 token 则继续(user 为 null
*/
export const optionalAuthMiddleware = new Elysia({ name: 'optional-auth' })
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(jwt(JWT_CONFIG))
.use(bearer())
.derive(async ({ bearer, jwt }) => {
.derive({ as: 'scoped' }, async ({ bearer, jwt }) => {
if (!bearer) {
return { user: null }
return { user: null as AuthUser | null }
}
const payload = await jwt.verify(bearer)
if (!payload) {
return { user: null }
return { user: null as AuthUser | null }
}
return {
@@ -66,7 +80,7 @@ export const optionalAuthMiddleware = new Elysia({ name: 'optional-auth' })
userId: payload.userId as string,
username: payload.username as string,
role: payload.role as string,
},
} as AuthUser | null,
}
})
@@ -75,15 +89,9 @@ export const optionalAuthMiddleware = new Elysia({ name: 'optional-auth' })
* 验证 token 并检查用户是否为管理员
*/
export const adminMiddleware = new Elysia({ name: 'admin' })
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(jwt(JWT_CONFIG))
.use(bearer())
.derive(async ({ bearer, jwt, set }) => {
.derive({ as: 'scoped' }, async ({ bearer, jwt, set }) => {
if (!bearer) {
set.status = 401
throw new Error('No bearer token provided')
@@ -96,7 +104,7 @@ export const adminMiddleware = new Elysia({ name: 'admin' })
throw new Error('Invalid or expired token')
}
const user = {
const user: AuthUser = {
userId: payload.userId as string,
username: payload.username as string,
role: payload.role as string,
+133
View File
@@ -0,0 +1,133 @@
import { Elysia } from 'elysia'
/**
*
*/
export interface ErrorResponse {
success: false
error: string
code?: string
details?: unknown
}
/**
*
*/
export interface SuccessResponse<T = unknown> {
success: true
data: T
message?: string
}
/**
*
*
*/
export const errorMiddleware = new Elysia({ name: 'error' })
.onError(({ code, error, set }) => {
console.error('[Error]', code, error)
// 根据不同的错误类型设置不同的状态码和响应
switch (code) {
case 'VALIDATION':
set.status = 400
return {
success: false,
error: 'Validation failed',
code: 'VALIDATION_ERROR',
details: error.message,
} satisfies ErrorResponse
case 'NOT_FOUND':
set.status = 404
return {
success: false,
error: 'Resource not found',
code: 'NOT_FOUND',
} satisfies ErrorResponse
case 'PARSE':
set.status = 400
return {
success: false,
error: 'Invalid request format',
code: 'PARSE_ERROR',
details: error.message,
} satisfies ErrorResponse
case 'INTERNAL_SERVER_ERROR':
set.status = 500
return {
success: false,
error: 'Internal server error',
code: 'INTERNAL_SERVER_ERROR',
} satisfies ErrorResponse
case 'UNKNOWN':
default:
// 处理自定义错误
if (error instanceof Error) {
const message = error.message
// 401 未授权错误
if (
message.includes('No bearer token') ||
message.includes('Invalid or expired token')
) {
set.status = 401
return {
success: false,
error: message,
code: 'UNAUTHORIZED',
} satisfies ErrorResponse
}
// 403 权限不足错误
if (message.includes('Forbidden') || message.includes('Admin access required')) {
set.status = 403
return {
success: false,
error: message,
code: 'FORBIDDEN',
} satisfies ErrorResponse
}
// 409 冲突错误(如用户已存在)
if (message.includes('already exists')) {
set.status = 409
return {
success: false,
error: message,
code: 'CONFLICT',
} satisfies ErrorResponse
}
// 404 未找到错误
if (message.includes('not found')) {
set.status = 404
return {
success: false,
error: message,
code: 'NOT_FOUND',
} satisfies ErrorResponse
}
// 默认 500 服务器错误
set.status = 500
return {
success: false,
error: message,
code: 'ERROR',
} satisfies ErrorResponse
}
// 未知错误
set.status = 500
return {
success: false,
error: 'An unexpected error occurred',
code: 'UNKNOWN_ERROR',
} satisfies ErrorResponse
}
})
+2 -1
View File
@@ -1,2 +1,3 @@
export * from './auth'
export * from './cors'
export * from './cors'
export * from './error'
+2 -31
View File
@@ -1,6 +1,5 @@
import Elysia from 'elysia'
import { bearer } from '@elysiajs/bearer'
import { jwt } from '@elysiajs/jwt'
import { authMiddleware } from '../../middlewares/auth'
import { AgentStreamModel } from './model'
import { createAgentStream } from './service'
import { getChatModel, getEmbeddingModel, getSummaryModel } from '../model/service'
@@ -10,35 +9,7 @@ import { ChatModel, EmbeddingModel } from '@memohome/shared'
export const agentModule = new Elysia({
prefix: '/agent',
})
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(bearer())
.derive(async ({ bearer, jwt, set }) => {
if (!bearer) {
set.status = 401
throw new Error('No bearer token provided')
}
const payload = await jwt.verify(bearer)
if (!payload) {
set.status = 401
throw new Error('Invalid or expired token')
}
return {
user: {
userId: payload.userId as string,
username: payload.username as string,
role: payload.role as string,
},
}
})
.use(authMiddleware)
// Stream agent conversation
.post('/stream', async ({ user, body, set }) => {
try {
+2 -10
View File
@@ -1,20 +1,12 @@
import Elysia from 'elysia'
import { bearer } from '@elysiajs/bearer'
import { jwt } from '@elysiajs/jwt'
import { jwtPlugin } from '../../middlewares/auth'
import { LoginModel } from './model'
import { validateUser } from './service'
export const authModule = new Elysia({
prefix: '/auth',
})
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(bearer())
.use(jwtPlugin)
// Login endpoint
.post('/login', async ({ body, jwt, set }) => {
try {
+1 -1
View File
@@ -63,7 +63,7 @@ export const validateUser = async (username: string, password: string) => {
}
// 检查账户是否激活
if (user.isActive !== 'true') {
if (!user.isActive) {
return null
}
+2 -31
View File
@@ -1,6 +1,5 @@
import Elysia from 'elysia'
import { bearer } from '@elysiajs/bearer'
import { jwt } from '@elysiajs/jwt'
import { authMiddleware } from '../../middlewares/auth'
import { messageModule } from './message'
import { AddMemoryModel, SearchMemoryModel } from './model'
import { addMemory, searchMemory } from './service'
@@ -9,35 +8,7 @@ import { MemoryUnit } from '@memohome/memory'
export const memoryModule = new Elysia({
prefix: '/memory',
})
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(bearer())
.derive(async ({ bearer, jwt, set }) => {
if (!bearer) {
set.status = 401
throw new Error('No bearer token provided')
}
const payload = await jwt.verify(bearer)
if (!payload) {
set.status = 401
throw new Error('Invalid or expired token')
}
return {
user: {
userId: payload.userId as string,
username: payload.username as string,
role: payload.role as string,
},
}
})
.use(authMiddleware)
.use(messageModule)
// Add memory for current user
.post('/', async ({ user, body, set }) => {
@@ -1,20 +1,12 @@
import Elysia from 'elysia'
import { bearer } from '@elysiajs/bearer'
import { jwt } from '@elysiajs/jwt'
import { authMiddleware } from '../../../middlewares/auth'
import { GetMemoryMessageFilterModel, GetMemoryMessageModel } from './model'
import { getMemoryMessages, getMemoryMessagesFilter } from './service'
export const messageModule = new Elysia({
prefix: '/message',
})
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(bearer())
.use(authMiddleware)
.derive(async ({ bearer, jwt, set }) => {
if (!bearer) {
set.status = 401
+85 -62
View File
@@ -1,4 +1,5 @@
import Elysia from 'elysia'
import { adminMiddleware, optionalAuthMiddleware } from '../../middlewares/auth'
import {
CreateModelModel,
UpdateModelModel,
@@ -21,13 +22,24 @@ import { Model } from '@memohome/shared'
export const modelModule = new Elysia({
prefix: '/model',
})
// 公开的读取接口
.use(optionalAuthMiddleware)
// Get all models
.get('/', async () => {
.get('/', async ({ query }) => {
try {
const models = await getModels()
const page = parseInt(query.page as string) || 1
const limit = parseInt(query.limit as string) || 10
const sortOrder = (query.sortOrder as string) || 'desc'
const result = await getModels({
page,
limit,
sortOrder: sortOrder as 'asc' | 'desc',
})
return {
success: true,
data: models,
...result,
}
} catch (error) {
return {
@@ -58,65 +70,6 @@ export const modelModule = new Elysia({
}
}
}, GetModelByIdModel)
// Create new model
.post('/', async ({ body }) => {
try {
const newModel = await createModel(body as Model)
return {
success: true,
data: newModel,
}
} catch (error) {
return {
success: false,
error: error instanceof Error ? error.message : 'Failed to create model',
}
}
}, CreateModelModel)
// Update model
.put('/:id', async ({ params, body }) => {
try {
const { id } = params
const updatedModel = await updateModel(id, body as Model)
if (!updatedModel) {
return {
success: false,
error: 'Model not found',
}
}
return {
success: true,
data: updatedModel,
}
} catch (error) {
return {
success: false,
error: error instanceof Error ? error.message : 'Failed to update model',
}
}
}, UpdateModelModel)
// Delete model
.delete('/:id', async ({ params }) => {
try {
const { id } = params
const deletedModel = await deleteModel(id)
if (!deletedModel) {
return {
success: false,
error: 'Model not found',
}
}
return {
success: true,
data: deletedModel,
}
} catch (error) {
return {
success: false,
error: error instanceof Error ? error.message : 'Failed to delete model',
}
}
}, DeleteModelModel)
// Get default chat model
.get('/chat/default', async ({ query }) => {
try {
@@ -183,3 +136,73 @@ export const modelModule = new Elysia({
}
}
}, GetDefaultModelModel)
// 管理员权限的写入接口
.guard(
{
beforeHandle: () => {
// This will be overridden by adminMiddleware
},
},
(app) =>
app
.use(adminMiddleware)
// Create new model
.post('/', async ({ body }) => {
try {
const newModel = await createModel(body as Model)
return {
success: true,
data: newModel,
}
} catch (error) {
return {
success: false,
error: error instanceof Error ? error.message : 'Failed to create model',
}
}
}, CreateModelModel)
// Update model
.put('/:id', async ({ params, body }) => {
try {
const { id } = params
const updatedModel = await updateModel(id, body as Model)
if (!updatedModel) {
return {
success: false,
error: 'Model not found',
}
}
return {
success: true,
data: updatedModel,
}
} catch (error) {
return {
success: false,
error: error instanceof Error ? error.message : 'Failed to update model',
}
}
}, UpdateModelModel)
// Delete model
.delete('/:id', async ({ params }) => {
try {
const { id } = params
const deletedModel = await deleteModel(id)
if (!deletedModel) {
return {
success: false,
error: 'Model not found',
}
}
return {
success: true,
data: deletedModel,
}
} catch (error) {
return {
success: false,
error: error instanceof Error ? error.message : 'Failed to delete model',
}
}
}, DeleteModelModel)
)
+35 -4
View File
@@ -1,12 +1,43 @@
import { db } from '@memohome/db'
import { model } from '@memohome/db/schema'
import { Model } from '@memohome/shared'
import { eq } from 'drizzle-orm'
import { eq, sql, desc, asc } from 'drizzle-orm'
import { getSettings } from '@/modules/settings/service'
import { calculateOffset, createPaginatedResult, type PaginatedResult } from '../../utils/pagination'
export const getModels = async () => {
const models = await db.select().from(model)
return models
/**
*
*/
type ModelListItem = {
id: string
model: Model
}
export const getModels = async (params?: {
page?: number
limit?: number
sortOrder?: 'asc' | 'desc'
}): Promise<PaginatedResult<ModelListItem>> => {
const page = params?.page || 1
const limit = params?.limit || 10
const sortOrder = params?.sortOrder || 'desc'
const offset = calculateOffset(page, limit)
// 获取总数
const [{ count }] = await db
.select({ count: sql<number>`count(*)` })
.from(model)
// 获取分页数据(按 id 排序,因为 model 表没有 createdAt
const orderFn = sortOrder === 'desc' ? desc : asc
const models = await db
.select()
.from(model)
.orderBy(orderFn(model.id))
.limit(limit)
.offset(offset)
return createPaginatedResult(models, Number(count), page, limit)
}
export const getModelById = async (id: string) => {
+2 -31
View File
@@ -1,41 +1,12 @@
import Elysia from 'elysia'
import { bearer } from '@elysiajs/bearer'
import { jwt } from '@elysiajs/jwt'
import { authMiddleware } from '../../middlewares/auth'
import { UpdateSettingsModel } from './model'
import { getSettings, upsertSettings } from './service'
export const settingsModule = new Elysia({
prefix: '/settings',
})
.use(
jwt({
name: 'jwt',
secret: process.env.JWT_SECRET || 'your-secret-key-change-in-production',
exp: process.env.JWT_EXPIRES_IN || '7d',
})
)
.use(bearer())
.derive(async ({ bearer, jwt, set }) => {
if (!bearer) {
set.status = 401
throw new Error('No bearer token provided')
}
const payload = await jwt.verify(bearer)
if (!payload) {
set.status = 401
throw new Error('Invalid or expired token')
}
return {
user: {
userId: payload.userId as string,
username: payload.username as string,
role: payload.role as string,
},
}
})
.use(authMiddleware)
// Get current user's settings
.get('/', async ({ user, set }) => {
try {
+14 -3
View File
@@ -22,12 +22,23 @@ export const userModule = new Elysia({
// 使用管理员中间件保护所有路由
.use(adminMiddleware)
// Get all users
.get('/', async () => {
.get('/', async ({ query }) => {
try {
const userList = await getUsers()
const page = parseInt(query.page as string) || 1
const limit = parseInt(query.limit as string) || 10
const sortBy = query.sortBy as string || 'createdAt'
const sortOrder = (query.sortOrder as string) || 'desc'
const result = await getUsers({
page,
limit,
sortBy,
sortOrder: sortOrder as 'asc' | 'desc',
})
return {
success: true,
data: userList,
...result,
}
} catch (error) {
return {
+1 -1
View File
@@ -19,7 +19,7 @@ const UpdateUserSchema = z.object({
role: UserRoleSchema.optional(),
displayName: z.string().optional(),
avatarUrl: z.string().url('Invalid URL format').optional(),
isActive: z.enum(['true', 'false']).optional(),
isActive: z.boolean().optional(),
})
// 更新密码的 Schema
+50 -5
View File
@@ -1,12 +1,55 @@
import { db } from '@memohome/db'
import { users, settings } from '@memohome/db/schema'
import { eq } from 'drizzle-orm'
import { eq, sql, desc, asc } from 'drizzle-orm'
import type { CreateUserInput, UpdateUserInput } from './model'
import { calculateOffset, createPaginatedResult, type PaginatedResult } from '../../utils/pagination'
/**
*
*
*/
export const getUsers = async () => {
type UserListItem = {
id: string
username: string
email: string | null
role: 'admin' | 'member'
displayName: string | null
avatarUrl: string | null
isActive: boolean
createdAt: Date
updatedAt: Date
lastLoginAt: Date | null
}
/**
*
*/
export const getUsers = async (params?: {
page?: number
limit?: number
sortBy?: string
sortOrder?: 'asc' | 'desc'
}): Promise<PaginatedResult<UserListItem>> => {
const page = params?.page || 1
const limit = params?.limit || 10
const sortBy = params?.sortBy || 'createdAt'
const sortOrder = params?.sortOrder || 'desc'
const offset = calculateOffset(page, limit)
// 获取总数
const [{ count }] = await db
.select({ count: sql<number>`count(*)` })
.from(users)
// 动态排序
const orderColumn = sortBy === 'username' ? users.username :
sortBy === 'email' ? users.email :
sortBy === 'role' ? users.role :
sortBy === 'updatedAt' ? users.updatedAt :
users.createdAt
const orderFn = sortOrder === 'desc' ? desc : asc
// 获取分页数据
const userList = await db
.select({
id: users.id,
@@ -21,9 +64,11 @@ export const getUsers = async () => {
lastLoginAt: users.lastLoginAt,
})
.from(users)
.orderBy(users.createdAt)
.orderBy(orderFn(orderColumn))
.limit(limit)
.offset(offset)
return userList
return createPaginatedResult(userList, Number(count), page, limit)
}
/**
+73
View File
@@ -0,0 +1,73 @@
/**
*
*/
export interface PaginationParams {
page?: number
limit?: number
sortBy?: string
sortOrder?: 'asc' | 'desc'
}
/**
*
*/
export interface PaginatedResult<T> {
items: T[]
pagination: {
page: number
limit: number
total: number
totalPages: number
hasNext: boolean
hasPrev: boolean
}
}
/**
*
*/
export function parsePaginationParams(query: Record<string, any>): Required<PaginationParams> {
const page = Math.max(1, parseInt(query.page as string) || 1)
const limit = Math.min(100, Math.max(1, parseInt(query.limit as string) || 10))
const sortBy = query.sortBy as string || 'createdAt'
const sortOrder = (query.sortOrder as string)?.toLowerCase() === 'desc' ? 'desc' : 'asc'
return {
page,
limit,
sortBy,
sortOrder,
}
}
/**
*
*/
export function createPaginatedResult<T>(
items: T[],
total: number,
page: number,
limit: number
): PaginatedResult<T> {
const totalPages = Math.ceil(total / limit)
return {
items,
pagination: {
page,
limit,
total,
totalPages,
hasNext: page < totalPages,
hasPrev: page > 1,
},
}
}
/**
*
*/
export function calculateOffset(page: number, limit: number): number {
return (page - 1) * limit
}
+3 -2
View File
@@ -1,4 +1,5 @@
import { pgTable, timestamp, uuid, jsonb, text } from 'drizzle-orm/pg-core'
import { pgTable, timestamp, uuid, jsonb } from 'drizzle-orm/pg-core'
import { users } from './users'
export const history = pgTable(
'history',
@@ -6,6 +7,6 @@ export const history = pgTable(
id: uuid('id').primaryKey().defaultRandom(),
messages: jsonb('messages').notNull(),
timestamp: timestamp('timestamp').notNull(),
user: text('user').notNull(),
user: uuid('user').notNull().references(() => users.id),
}
)
+2 -1
View File
@@ -1,8 +1,9 @@
import { pgTable, text, uuid, integer } from 'drizzle-orm/pg-core'
import { model } from './model'
import { users } from './users'
export const settings = pgTable('settings', {
userId: text('user_id').primaryKey(),
userId: uuid('user_id').primaryKey().references(() => users.id),
defaultChatModel: uuid('default_chat_model').references(() => model.id),
defaultEmbeddingModel: uuid('default_embedding_model').references(() => model.id),
defaultSummaryModel: uuid('default_summary_model').references(() => model.id),
+2 -2
View File
@@ -1,4 +1,4 @@
import { pgTable, pgEnum, text, timestamp, uuid } from 'drizzle-orm/pg-core'
import { pgTable, pgEnum, text, timestamp, uuid, boolean } from 'drizzle-orm/pg-core'
// 定义用户角色枚举
export const userRoleEnum = pgEnum('user_role', ['admin', 'member'])
@@ -27,7 +27,7 @@ export const users = pgTable('users', {
avatarUrl: text('avatar_url'),
// 账户状态(是否激活)
isActive: text('is_active').notNull().default('true'),
isActive: boolean('is_active').notNull().default(true),
// 创建时间
createdAt: timestamp('created_at').notNull().defaultNow(),