mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat: refactor logging system to slog with DI and component tagging
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ctxKey struct{}
|
||||
|
||||
var (
|
||||
L *slog.Logger = slog.Default()
|
||||
logKey = ctxKey{}
|
||||
)
|
||||
|
||||
// Init 初始化全局日志
|
||||
func Init(level, format string) {
|
||||
var handler slog.Handler
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: parseLevel(level),
|
||||
}
|
||||
|
||||
if strings.ToLower(format) == "json" {
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
} else {
|
||||
handler = slog.NewTextHandler(os.Stdout, opts)
|
||||
}
|
||||
|
||||
L = slog.New(handler)
|
||||
slog.SetDefault(L)
|
||||
}
|
||||
|
||||
// FromContext 从 context 中获取 logger,如果不存在则返回全局 logger
|
||||
func FromContext(ctx context.Context) *slog.Logger {
|
||||
if l, ok := ctx.Value(logKey).(*slog.Logger); ok {
|
||||
return l
|
||||
}
|
||||
return L
|
||||
}
|
||||
|
||||
// WithContext 将 logger 注入 context
|
||||
func WithContext(ctx context.Context, l *slog.Logger) context.Context {
|
||||
return context.WithValue(ctx, logKey, l)
|
||||
}
|
||||
|
||||
func parseLevel(level string) slog.Level {
|
||||
switch strings.ToLower(level) {
|
||||
case "debug":
|
||||
return slog.LevelDebug
|
||||
case "info":
|
||||
return slog.LevelInfo
|
||||
case "warn":
|
||||
return slog.LevelWarn
|
||||
case "error":
|
||||
return slog.LevelError
|
||||
default:
|
||||
return slog.LevelInfo
|
||||
}
|
||||
}
|
||||
|
||||
// 快捷方法,支持强类型 slog.Attr 或松散的 key-value 对
|
||||
func Debug(msg string, args ...any) { L.Debug(msg, args...) }
|
||||
func Info(msg string, args ...any) { L.Info(msg, args...) }
|
||||
func Warn(msg string, args ...any) { L.Warn(msg, args...) }
|
||||
func Error(msg string, args ...any) { L.Error(msg, args...) }
|
||||
@@ -0,0 +1,55 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInitAndLogging(t *testing.T) {
|
||||
// 测试 JSON 格式
|
||||
Init("debug", "json")
|
||||
|
||||
if L.Enabled(context.Background(), slog.LevelDebug) != true {
|
||||
t.Error("expected debug level to be enabled")
|
||||
}
|
||||
|
||||
// 验证是否能正常输出(不崩溃)
|
||||
Info("test info message", "key", "value")
|
||||
}
|
||||
|
||||
func TestContextLogger(t *testing.T) {
|
||||
Init("info", "text")
|
||||
|
||||
// 创建一个带特定属性的 logger
|
||||
expectedKey := "request_id"
|
||||
expectedValue := "12345"
|
||||
customLogger := L.With(expectedKey, expectedValue)
|
||||
|
||||
ctx := WithContext(context.Background(), customLogger)
|
||||
extracted := FromContext(ctx)
|
||||
|
||||
// 这里简单验证提取出来的是否是同一个(或者功能一致)
|
||||
if extracted == nil {
|
||||
t.Fatal("extracted logger should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected slog.Level
|
||||
}{
|
||||
{"debug", slog.LevelDebug},
|
||||
{"INFO", slog.LevelInfo},
|
||||
{"Warn", slog.LevelWarn},
|
||||
{"error", slog.LevelError},
|
||||
{"unknown", slog.LevelInfo},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := parseLevel(tt.input); got != tt.expected {
|
||||
t.Errorf("parseLevel(%s) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user