feat: add media asset system, channel lifecycle refactor, and chat attachments (#54)

This commit is contained in:
BBQ
2026-02-17 19:06:46 +08:00
committed by GitHub
parent 0bdc31311c
commit df7876a30c
106 changed files with 7942 additions and 1274 deletions
+14
View File
@@ -0,0 +1,14 @@
package media
import "errors"
// Sentinel errors of the media package. Callers should match them with
// errors.Is because some call sites wrap them with extra context.
var (
// ErrAssetNotFound indicates the requested media asset does not exist.
// Service.Open and Service.GetByID return it when the lookup yields no row.
ErrAssetNotFound = errors.New("media asset not found")
// ErrProviderUnavailable indicates the storage provider is not configured or reachable.
// Service.Ingest and Service.Open return it when the service has no provider.
ErrProviderUnavailable = errors.New("storage provider unavailable")
// ErrAssetTooLarge indicates the payload exceeds the configured max asset size.
// It is wrapped together with the limit that was exceeded.
ErrAssetTooLarge = errors.New("media asset too large")
// ErrPathTraversal indicates a storage key attempted directory traversal.
// NOTE(review): no code visible here returns this sentinel — confirm which provider uses it.
ErrPathTraversal = errors.New("path traversal is forbidden")
)
+33
View File
@@ -0,0 +1,33 @@
package media
import (
"fmt"
"io"
)
// MaxAssetBytes is the package-wide ceiling on an accepted payload: 200 MiB.
// Service.Ingest falls back to it when IngestInput.MaxBytes is not positive.
const MaxAssetBytes int64 = 200 << 20
// ReadAllWithLimit reads reader to EOF and returns its bytes, rejecting
// payloads larger than maxBytes.
//
// It returns an error when reader is nil or maxBytes is not positive, passes
// through any read error unchanged, and returns an error wrapping
// ErrAssetTooLarge when the payload holds more than maxBytes bytes. At most
// maxBytes+1 bytes are consumed from reader.
func ReadAllWithLimit(reader io.Reader, maxBytes int64) ([]byte, error) {
	if reader == nil {
		return nil, fmt.Errorf("reader is required")
	}
	if maxBytes <= 0 {
		return nil, fmt.Errorf("max bytes must be greater than 0")
	}

	// Read one byte past the limit so an oversized payload can be detected.
	// When maxBytes is already the largest int64, maxBytes+1 would wrap to a
	// negative count and the limited reader would report EOF immediately, so
	// the probe byte is skipped; no payload can exceed that limit anyway.
	const maxInt64 = 1<<63 - 1
	probe := maxBytes
	if probe < maxInt64 {
		probe++
	}

	data, err := io.ReadAll(io.LimitReader(reader, probe))
	if err != nil {
		return nil, err
	}
	if int64(len(data)) > maxBytes {
		return nil, fmt.Errorf("%w: max %d bytes", ErrAssetTooLarge, maxBytes)
	}
	return data, nil
}
+60
View File
@@ -0,0 +1,60 @@
package media
import (
"bytes"
"errors"
"testing"
)
// TestReadAllWithLimit checks ReadAllWithLimit below, at, and above the
// byte limit, running each scenario as a parallel subtest.
func TestReadAllWithLimit(t *testing.T) {
	t.Parallel()

	type scenario struct {
		name      string
		payload   []byte
		maxBytes  int64
		wantErr   bool
		errTooBig bool
	}
	scenarios := []scenario{
		{name: "within limit", payload: []byte("hello"), maxBytes: 8},
		{name: "over limit", payload: []byte("0123456789"), maxBytes: 5, wantErr: true, errTooBig: true},
		{name: "exact limit", payload: []byte("12345"), maxBytes: 5},
	}

	for _, sc := range scenarios {
		sc := sc // per-iteration copy for the parallel closure
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()

			data, err := ReadAllWithLimit(bytes.NewReader(sc.payload), sc.maxBytes)
			switch {
			case sc.wantErr && err == nil:
				t.Fatalf("expected error")
			case sc.wantErr:
				// An error was expected; optionally pin its sentinel.
				if sc.errTooBig && !errors.Is(err, ErrAssetTooLarge) {
					t.Fatalf("expected ErrAssetTooLarge, got %v", err)
				}
			case err != nil:
				t.Fatalf("unexpected error: %v", err)
			case !bytes.Equal(data, sc.payload):
				t.Fatalf("unexpected payload: %q", string(data))
			}
		})
	}
}
@@ -0,0 +1,112 @@
// Package containerfs implements media.StorageProvider for bot containers
// backed by host-side bind mounts. Writing to <dataRoot>/bots/<bot_id>/media/<subpath>
// on the host makes the file available at /data/media/<subpath> inside the container.
package containerfs
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// containerMediaRoot is the path prefix AccessPath prepends to the sub-path
// of a storage key to form the path seen inside the container.
const containerMediaRoot = "/data/media"
// Provider stores media assets via the host-side bind mount path
// that maps to /data inside bot containers.
type Provider struct {
// dataRoot is the host directory under which keys resolve to
// "<dataRoot>/bots/<bot_id>/media/<subpath>". New stores it as an absolute path.
dataRoot string
}
// New builds a Provider rooted at dataRoot, the host directory holding
// per-bot data (e.g. "data"). The root is converted to an absolute path
// up front; an error is returned if that conversion fails.
func New(dataRoot string) (*Provider, error) {
	root, err := filepath.Abs(dataRoot)
	if err != nil {
		return nil, fmt.Errorf("resolve data root: %w", err)
	}

	provider := new(Provider)
	provider.dataRoot = root
	return provider, nil
}
// Put writes the bytes from reader to the host bind mount path derived from
// key, creating parent directories as needed and truncating any existing
// file at that path.
//
// It returns the key validation error from hostPath unchanged, an error when
// reader is nil, and wrapped errors for directory creation, file creation,
// writing, and closing. The context is not used.
func (p *Provider) Put(_ context.Context, key string, reader io.Reader) error {
	dest, err := p.hostPath(key)
	if err != nil {
		return err
	}
	// Reject a nil reader before touching the filesystem; io.Copy would
	// otherwise panic after an empty file had already been created.
	if reader == nil {
		return fmt.Errorf("reader is required")
	}
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		return fmt.Errorf("create parent dir: %w", err)
	}
	f, err := os.Create(dest)
	if err != nil {
		return fmt.Errorf("create file: %w", err)
	}
	if _, err := io.Copy(f, reader); err != nil {
		_ = f.Close() // best-effort; the write error is the one worth reporting
		return fmt.Errorf("write file: %w", err)
	}
	// A failed Close on a written file can mean the data never reached
	// storage, so it must be reported rather than dropped.
	if err := f.Close(); err != nil {
		return fmt.Errorf("close file: %w", err)
	}
	return nil
}
// Open resolves key to its host bind mount path and opens that file for
// reading. Key validation errors are returned unchanged; a failure to open
// the file is wrapped. The caller must close the returned reader. The
// context is not used.
func (p *Provider) Open(_ context.Context, key string) (io.ReadCloser, error) {
	src, err := p.hostPath(key)
	if err != nil {
		return nil, err
	}

	file, openErr := os.Open(src)
	if openErr == nil {
		return file, nil
	}
	return nil, fmt.Errorf("open file: %w", openErr)
}
// Delete removes the file that key resolves to on the host bind mount.
// A file that is already absent counts as success. Key validation errors
// are returned unchanged and any other removal failure is wrapped. The
// context is not used.
func (p *Provider) Delete(_ context.Context, key string) error {
	target, err := p.hostPath(key)
	if err != nil {
		return err
	}

	switch rmErr := os.Remove(target); {
	case rmErr == nil, os.IsNotExist(rmErr):
		return nil
	default:
		return fmt.Errorf("delete file: %w", rmErr)
	}
}
// AccessPath maps a storage key to the path seen inside the container.
// Everything up to and including the first "/" (the bot ID segment) is
// dropped and the remainder is appended to containerMediaRoot:
// "<bot_id>/<subpath>" → "/data/media/<subpath>". A key without any "/"
// is appended as-is.
func (p *Provider) AccessPath(key string) string {
	_, rest, hasPrefix := strings.Cut(key, "/")
	if !hasPrefix {
		rest = key
	}
	return containerMediaRoot + "/" + rest
}
// hostPath translates a storage key into the host-side file path:
// "<bot_id>/<subpath>" → "<dataRoot>/bots/<bot_id>/media/<subpath>".
//
// The key is cleaned first and then rejected when it is absolute, when it
// climbs out via "..", when it lacks a bot ID segment followed by a
// sub-path, or when the joined result does not sit under dataRoot.
func (p *Provider) hostPath(key string) (string, error) {
	sep := string(filepath.Separator)
	normalized := filepath.Clean(key)

	switch {
	case filepath.IsAbs(normalized):
		return "", fmt.Errorf("absolute key is forbidden: %s", key)
	case normalized == "..", strings.HasPrefix(normalized, ".."+sep):
		return "", fmt.Errorf("path traversal is forbidden: %s", key)
	}

	// The first separator splits the bot ID from the sub-path; a cut at
	// index 0 or no cut at all means the bot ID segment is missing.
	cut := strings.IndexByte(normalized, filepath.Separator)
	if cut <= 0 {
		return "", fmt.Errorf("storage key must contain bot_id prefix: %s", key)
	}
	bot, rest := normalized[:cut], normalized[cut+1:]
	if strings.TrimSpace(bot) == "" || strings.TrimSpace(rest) == "" {
		return "", fmt.Errorf("invalid storage key: %s", key)
	}

	full := filepath.Join(p.dataRoot, "bots", bot, "media", rest)
	if !strings.HasPrefix(full, p.dataRoot+sep) {
		return "", fmt.Errorf("path escapes data root: %s", key)
	}
	return full, nil
}
@@ -0,0 +1,116 @@
package containerfs
import (
"bytes"
"context"
"io"
"os"
"path/filepath"
"testing"
)
// TestProvider_HostPath checks that hostPath maps a well-formed key to its
// host-side path and rejects absolute, traversing, prefix-less and empty keys.
func TestProvider_HostPath(t *testing.T) {
	t.Parallel()

	provider := &Provider{dataRoot: "/srv/data"}
	cases := []struct {
		key     string
		want    string
		wantErr bool
	}{
		{key: "bot-1/image/ab12/ab12cd.png", want: "/srv/data/bots/bot-1/media/image/ab12/ab12cd.png"},
		{key: "/absolute/path", wantErr: true},
		{key: "../escape", wantErr: true},
		{key: "nosubpath", wantErr: true},
		{key: "", wantErr: true},
	}

	for _, tc := range cases {
		got, err := provider.hostPath(tc.key)
		switch {
		case tc.wantErr && err == nil:
			t.Errorf("hostPath(%q) expected error", tc.key)
		case tc.wantErr:
			// Rejected as expected; nothing further to compare.
		case err != nil:
			t.Errorf("hostPath(%q) unexpected error: %v", tc.key, err)
		case got != tc.want:
			t.Errorf("hostPath(%q) = %q, want %q", tc.key, got, tc.want)
		}
	}
}
// TestProvider_AccessPath checks that AccessPath drops the bot ID segment
// of a key and roots the remainder at /data/media.
func TestProvider_AccessPath(t *testing.T) {
	t.Parallel()

	provider := &Provider{dataRoot: "/srv/data"}
	// Each pair is {storage key, expected in-container path}.
	pairs := [][2]string{
		{"bot-1/image/ab12/ab12cd.png", "/data/media/image/ab12/ab12cd.png"},
		{"bot-1/file/xx/doc.pdf", "/data/media/file/xx/doc.pdf"},
	}

	for _, pair := range pairs {
		key, want := pair[0], pair[1]
		if got := provider.AccessPath(key); got != want {
			t.Errorf("AccessPath(%q) = %q, want %q", key, got, want)
		}
	}
}
// TestProvider_PutOpenDelete runs a full round trip against a temporary data
// root: Put must create the file at the expected host path, Open must return
// the same bytes, and Delete must remove the file.
func TestProvider_PutOpenDelete(t *testing.T) {
	t.Parallel()

	tmpDir := t.TempDir()
	p, err := New(tmpDir)
	if err != nil {
		t.Fatalf("New failed: %v", err)
	}

	ctx := context.Background()
	key := "bot-1/image/ab/test.png"
	data := []byte("hello media content")

	if err := p.Put(ctx, key, bytes.NewReader(data)); err != nil {
		t.Fatalf("Put failed: %v", err)
	}
	hostFile := filepath.Join(tmpDir, "bots", "bot-1", "media", "image", "ab", "test.png")
	if _, err := os.Stat(hostFile); err != nil {
		t.Fatalf("file not found on host: %v", err)
	}

	reader, err := p.Open(ctx, key)
	if err != nil {
		t.Fatalf("Open failed: %v", err)
	}
	// Close before checking the read result so the handle is released even
	// when the read failed, and before Delete runs below.
	got, readErr := io.ReadAll(reader)
	closeErr := reader.Close()
	if readErr != nil {
		t.Fatalf("ReadAll failed: %v", readErr)
	}
	if closeErr != nil {
		t.Fatalf("Close failed: %v", closeErr)
	}
	if !bytes.Equal(got, data) {
		t.Errorf("Open returned %q, want %q", got, data)
	}

	if err := p.Delete(ctx, key); err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
	if _, err := os.Stat(hostFile); !os.IsNotExist(err) {
		t.Fatalf("file should be deleted: %v", err)
	}
}
// TestProvider_PathTraversal verifies that hostPath refuses keys that are
// absolute or that climb out of the bot directory with "..".
func TestProvider_PathTraversal(t *testing.T) {
	t.Parallel()

	provider := &Provider{dataRoot: "/srv/data"}
	for _, key := range [...]string{
		"../etc/passwd",
		"/absolute/key",
		"bot-1/../../escape",
	} {
		_, err := provider.hostPath(key)
		if err != nil {
			continue
		}
		t.Errorf("hostPath(%q) should reject traversal", key)
	}
}
+356
View File
@@ -0,0 +1,356 @@
package media
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
dbpkg "github.com/memohai/memoh/internal/db"
"github.com/memohai/memoh/internal/db/sqlc"
)
// Service provides media asset persistence operations.
// It pairs database records (via queries) with the stored bytes (via provider).
type Service struct {
// queries runs the generated SQL for media assets and message-asset links.
queries *sqlc.Queries
// provider stores and serves asset bytes. When nil, Ingest and Open return
// ErrProviderUnavailable and AccessPath returns "".
provider StorageProvider
// logger is the logger given to NewService (or slog.Default when nil),
// tagged with service=media.
logger *slog.Logger
}
// NewService wires a media Service around the given queries and storage
// provider. A nil log falls back to slog.Default; either way the logger is
// tagged with service=media.
func NewService(log *slog.Logger, queries *sqlc.Queries, provider StorageProvider) *Service {
	base := log
	if base == nil {
		base = slog.Default()
	}

	svc := &Service{queries: queries, provider: provider}
	svc.logger = base.With(slog.String("service", "media"))
	return svc
}
// Ingest persists a new media asset and returns it, or returns the existing
// asset when the same bot already has one with identical content.
//
// The payload is spooled to a temporary file while its SHA-256 hash is
// computed, limited to input.MaxBytes (MaxAssetBytes when not positive).
// If no asset exists for (bot_id, content_hash), the bytes are stored via
// the provider under "<bot_id>/<media_type>/<hash[:4]>/<hash><ext>" and a
// DB record is created. The temporary file is removed before returning.
//
// Errors: ErrProviderUnavailable when no provider is configured; plain
// errors for a blank bot ID, a nil reader, or an invalid bot ID; wrapped
// errors for reading the payload (including ErrAssetTooLarge), the dedup
// lookup, storing the bytes, and creating the record. The caller remains
// responsible for closing input.Reader.
func (s *Service) Ingest(ctx context.Context, input IngestInput) (Asset, error) {
	if s.provider == nil {
		return Asset{}, ErrProviderUnavailable
	}
	if strings.TrimSpace(input.BotID) == "" {
		return Asset{}, fmt.Errorf("bot id is required")
	}
	if input.Reader == nil {
		return Asset{}, fmt.Errorf("reader is required")
	}
	// Validate the bot ID before touching the payload so a malformed ID is
	// rejected without first spooling the whole upload to disk.
	pgBotID, err := dbpkg.ParseUUID(input.BotID)
	if err != nil {
		return Asset{}, fmt.Errorf("invalid bot id: %w", err)
	}

	maxBytes := input.MaxBytes
	if maxBytes <= 0 {
		maxBytes = MaxAssetBytes
	}
	contentHash, sizeBytes, tempPath, err := spoolAndHashWithLimit(input.Reader, maxBytes)
	if err != nil {
		return Asset{}, fmt.Errorf("read input: %w", err)
	}
	defer func() {
		_ = os.Remove(tempPath) // best-effort cleanup of the spool file
	}()

	// Dedup: only create when hash truly not found; propagate other DB errors.
	existing, err := s.queries.GetMediaAssetByHash(ctx, sqlc.GetMediaAssetByHashParams{
		BotID:       pgBotID,
		ContentHash: contentHash,
	})
	if err == nil {
		return convertAsset(existing), nil
	}
	if !errors.Is(err, pgx.ErrNoRows) {
		return Asset{}, fmt.Errorf("check existing asset: %w", err)
	}

	ext := extensionFromMime(input.Mime)
	// The first four hash characters form a fan-out directory level.
	storageKey := path.Join(
		input.BotID,
		string(input.MediaType),
		contentHash[:4],
		contentHash+ext,
	)
	tempFile, err := os.Open(tempPath)
	if err != nil {
		return Asset{}, fmt.Errorf("open temp file: %w", err)
	}
	defer func() {
		_ = tempFile.Close() // read-only handle; nothing to lose on close
	}()
	if err := s.provider.Put(ctx, storageKey, tempFile); err != nil {
		return Asset{}, fmt.Errorf("store media: %w", err)
	}

	metaBytes, err := json.Marshal(nonNilMap(input.Metadata))
	if err != nil {
		// Metadata is best-effort: keep the asset, store empty metadata,
		// but leave a trace of why the caller's metadata was dropped.
		if s.logger != nil {
			s.logger.Warn("media metadata not serializable; storing empty metadata",
				slog.String("bot_id", input.BotID),
				slog.String("content_hash", contentHash),
				slog.Any("error", err),
			)
		}
		metaBytes = []byte("{}")
	}

	row, err := s.queries.CreateMediaAsset(ctx, sqlc.CreateMediaAssetParams{
		BotID:       pgBotID,
		ContentHash: contentHash,
		MediaType:   string(input.MediaType),
		Mime:        coalesce(input.Mime, "application/octet-stream"),
		SizeBytes:   sizeBytes,
		StorageKey:  storageKey,
		OriginalName: pgtype.Text{
			String: input.OriginalName,
			Valid:  strings.TrimSpace(input.OriginalName) != "",
		},
		Width:      toPgInt4(input.Width),
		Height:     toPgInt4(input.Height),
		DurationMs: toPgInt8(input.DurationMs),
		Metadata:   metaBytes,
	})
	if err != nil {
		return Asset{}, fmt.Errorf("create asset record: %w", err)
	}
	return convertAsset(row), nil
}
// Open looks up the asset with the given ID and opens its stored bytes.
// It returns ErrProviderUnavailable when no provider is configured, the
// lookup errors of GetByID (invalid ID, ErrAssetNotFound, query failure),
// or a wrapped storage error. The caller must close the returned reader.
func (s *Service) Open(ctx context.Context, assetID string) (io.ReadCloser, Asset, error) {
	if s.provider == nil {
		return nil, Asset{}, ErrProviderUnavailable
	}

	asset, err := s.GetByID(ctx, assetID)
	if err != nil {
		return nil, Asset{}, err
	}

	body, err := s.provider.Open(ctx, asset.StorageKey)
	if err != nil {
		return nil, Asset{}, fmt.Errorf("open storage: %w", err)
	}
	return body, asset, nil
}
// GetByID fetches the asset record with the given ID. It returns a wrapped
// parse error for a malformed ID, ErrAssetNotFound when no row matches, and
// a wrapped error for any other query failure.
func (s *Service) GetByID(ctx context.Context, assetID string) (Asset, error) {
	id, parseErr := dbpkg.ParseUUID(assetID)
	if parseErr != nil {
		return Asset{}, fmt.Errorf("invalid asset id: %w", parseErr)
	}

	record, err := s.queries.GetMediaAssetByID(ctx, id)
	switch {
	case err == nil:
		return convertAsset(record), nil
	case errors.Is(err, pgx.ErrNoRows):
		return Asset{}, ErrAssetNotFound
	default:
		return Asset{}, fmt.Errorf("get asset: %w", err)
	}
}
// LinkToMessage creates a message-asset relationship with the given role
// and ordinal. A blank role defaults to "attachment".
//
// It returns wrapped parse errors for malformed message or asset IDs, an
// error when ordinal does not fit the 32-bit value passed to the query, and
// a wrapped error when the insert fails.
func (s *Service) LinkToMessage(ctx context.Context, messageID, assetID, role string, ordinal int) error {
	pgMsgID, err := dbpkg.ParseUUID(messageID)
	if err != nil {
		return fmt.Errorf("invalid message id: %w", err)
	}
	pgAssetID, err := dbpkg.ParseUUID(assetID)
	if err != nil {
		return fmt.Errorf("invalid asset id: %w", err)
	}
	// The query takes an int32; reject values that would silently wrap
	// when narrowed instead of storing a different ordinal.
	if int(int32(ordinal)) != ordinal {
		return fmt.Errorf("ordinal %d out of range", ordinal)
	}
	if strings.TrimSpace(role) == "" {
		role = "attachment"
	}
	if _, err := s.queries.CreateMessageAsset(ctx, sqlc.CreateMessageAssetParams{
		MessageID: pgMsgID,
		AssetID:   pgAssetID,
		Role:      role,
		Ordinal:   int32(ordinal),
	}); err != nil {
		return fmt.Errorf("create message asset: %w", err)
	}
	return nil
}
// ListMessageAssets returns the assets linked to the given message, in the
// order the query yields them. A malformed message ID produces a wrapped
// parse error; a query failure is returned as-is. With no linked assets the
// result is an empty, non-nil slice.
func (s *Service) ListMessageAssets(ctx context.Context, messageID string) ([]Asset, error) {
	pgMsgID, err := dbpkg.ParseUUID(messageID)
	if err != nil {
		return nil, fmt.Errorf("invalid message id: %w", err)
	}

	records, err := s.queries.ListMessageAssets(ctx, pgMsgID)
	if err != nil {
		return nil, err
	}

	out := make([]Asset, len(records))
	for i := range records {
		rec := records[i]
		out[i] = Asset{
			ID:           rec.AssetID.String(),
			MediaType:    MediaType(rec.MediaType),
			Mime:         rec.Mime,
			SizeBytes:    rec.SizeBytes,
			StorageKey:   rec.StorageKey,
			OriginalName: dbpkg.TextToString(rec.OriginalName),
			Width:        int(rec.Width.Int32),
			Height:       int(rec.Height.Int32),
			DurationMs:   rec.DurationMs.Int64,
		}
	}
	return out, nil
}
// AccessPath asks the storage provider for the consumer-facing reference to
// the asset's storage key. It returns "" when no provider is configured.
func (s *Service) AccessPath(asset Asset) string {
	if provider := s.provider; provider != nil {
		return provider.AccessPath(asset.StorageKey)
	}
	return ""
}
// --- helpers ---
// convertAsset maps a media asset DB row onto the domain Asset. Nullable
// columns are copied only when valid, leaving the zero value otherwise.
// Metadata is decoded from the row's JSON bytes on a best-effort basis: a
// decode failure is ignored rather than failing the conversion.
func convertAsset(row sqlc.MediaAsset) Asset {
	asset := Asset{
		ID:          row.ID.String(),
		BotID:       row.BotID.String(),
		ContentHash: row.ContentHash,
		MediaType:   MediaType(row.MediaType),
		Mime:        row.Mime,
		SizeBytes:   row.SizeBytes,
		StorageKey:  row.StorageKey,
		CreatedAt:   row.CreatedAt.Time,
	}

	if providerID := row.StorageProviderID; providerID.Valid {
		asset.StorageProviderID = providerID.String()
	}
	if name := row.OriginalName; name.Valid {
		asset.OriginalName = name.String
	}
	if width := row.Width; width.Valid {
		asset.Width = int(width.Int32)
	}
	if height := row.Height; height.Valid {
		asset.Height = int(height.Int32)
	}
	if duration := row.DurationMs; duration.Valid {
		asset.DurationMs = duration.Int64
	}

	if len(row.Metadata) > 0 {
		_ = json.Unmarshal(row.Metadata, &asset.Metadata)
	}
	return asset
}
// extensionFromMime returns the file extension (with leading dot) used in
// storage keys for the given MIME type. Matching ignores case, surrounding
// whitespace, and any parameters after ";" (so "audio/ogg; codecs=opus"
// maps to ".ogg"). Unknown or empty types map to ".bin".
func extensionFromMime(mime string) string {
	// Keep only the media type; parameters such as charset or codecs do not
	// influence the extension.
	mediaType, _, _ := strings.Cut(mime, ";")
	switch strings.ToLower(strings.TrimSpace(mediaType)) {
	case "image/png":
		return ".png"
	case "image/jpeg", "image/jpg":
		return ".jpg"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "audio/mpeg", "audio/mp3":
		return ".mp3"
	case "audio/wav":
		return ".wav"
	case "audio/ogg":
		return ".ogg"
	case "video/mp4":
		return ".mp4"
	case "video/webm":
		return ".webm"
	case "application/pdf":
		return ".pdf"
	default:
		return ".bin"
	}
}
// nonNilMap returns m itself when it is non-nil and a fresh empty map
// otherwise, so that JSON encoding yields "{}" instead of "null".
func nonNilMap(m map[string]any) map[string]any {
	if m != nil {
		return m
	}
	return make(map[string]any)
}
// coalesce returns the first argument that contains something other than
// whitespace, unmodified (it is not trimmed). It returns "" when every
// argument is blank or none are given.
func coalesce(values ...string) string {
	for i := range values {
		if candidate := values[i]; len(strings.TrimSpace(candidate)) > 0 {
			return candidate
		}
	}
	return ""
}
// toPgInt4 converts v to a nullable 32-bit column value. Zero is treated
// as "not set" and yields the invalid (NULL) value; any other v is
// narrowed to int32 and marked valid.
func toPgInt4(v int) pgtype.Int4 {
	return pgtype.Int4{Int32: int32(v), Valid: v != 0}
}
// toPgInt8 converts v to a nullable 64-bit column value. Zero is treated
// as "not set" and yields the invalid (NULL) value; any other v is marked
// valid.
func toPgInt8(v int64) pgtype.Int8 {
	return pgtype.Int8{Int64: v, Valid: v != 0}
}
// spoolAndHashWithLimit copies reader into a new temporary file while
// computing the SHA-256 of the bytes. It returns the hex-encoded hash, the
// number of bytes written, and the temp file path; on success the caller
// owns the file and must remove it.
//
// It fails when reader is nil, maxBytes is not positive, the payload is
// empty, the payload exceeds maxBytes (error wraps ErrAssetTooLarge), or
// any file operation fails. On every failure the temp file is removed.
func spoolAndHashWithLimit(reader io.Reader, maxBytes int64) (string, int64, string, error) {
	if reader == nil {
		return "", 0, "", fmt.Errorf("reader is required")
	}
	if maxBytes <= 0 {
		return "", 0, "", fmt.Errorf("max bytes must be greater than 0")
	}

	tempFile, err := os.CreateTemp("", "memoh-media-*")
	if err != nil {
		return "", 0, "", fmt.Errorf("create temp file: %w", err)
	}
	tempPath := tempFile.Name()
	keepFile := false
	defer func() {
		if !keepFile {
			// Failure path: the file is being discarded, so close and
			// remove errors carry no useful information.
			_ = tempFile.Close()
			_ = os.Remove(tempPath)
		}
	}()

	// Copy one byte past the limit so an oversized payload can be detected.
	// When maxBytes is already the largest int64, maxBytes+1 would wrap to a
	// negative count and nothing would be read, so the probe byte is
	// skipped; no payload can exceed that limit anyway.
	const maxInt64 = 1<<63 - 1
	probe := maxBytes
	if probe < maxInt64 {
		probe++
	}

	hasher := sha256.New()
	written, err := io.Copy(io.MultiWriter(tempFile, hasher), io.LimitReader(reader, probe))
	if err != nil {
		return "", 0, "", fmt.Errorf("copy to temp file: %w", err)
	}
	if written > maxBytes {
		return "", 0, "", fmt.Errorf("%w: max %d bytes", ErrAssetTooLarge, maxBytes)
	}
	if written == 0 {
		return "", 0, "", fmt.Errorf("asset payload is empty")
	}
	// The path is handed back to be reopened for reading, so a failed Close
	// (which can mean the data never reached storage) must not be ignored.
	if err := tempFile.Close(); err != nil {
		return "", 0, "", fmt.Errorf("close temp file: %w", err)
	}

	keepFile = true
	return hex.EncodeToString(hasher.Sum(nil)), written, tempPath, nil
}
+71
View File
@@ -0,0 +1,71 @@
package media
import (
"context"
"io"
"time"
)
// MediaType classifies the kind of media asset.
// Its string value is stored with the asset record and is used by
// Service.Ingest as a path segment of the storage key.
type MediaType string
// Known media types.
const (
// MediaTypeImage marks an image asset.
MediaTypeImage MediaType = "image"
// MediaTypeAudio marks an audio asset.
MediaTypeAudio MediaType = "audio"
// MediaTypeVideo marks a video asset.
MediaTypeVideo MediaType = "video"
// MediaTypeFile marks a generic file asset.
MediaTypeFile MediaType = "file"
)
// Asset is the domain representation of a persisted media object.
// Optional fields are left at their zero value when the stored column is NULL
// and are omitted from JSON in that case.
type Asset struct {
// ID is the asset's identifier in string form.
ID string `json:"id"`
// BotID identifies the bot that owns the asset.
BotID string `json:"bot_id"`
// StorageProviderID is set only when the record carries a provider ID.
StorageProviderID string `json:"storage_provider_id,omitempty"`
// ContentHash is the hex-encoded SHA-256 of the content for assets created
// by Service.Ingest; together with BotID it is the deduplication key.
ContentHash string `json:"content_hash"`
MediaType MediaType `json:"media_type"`
// Mime is the stored MIME type; Service.Ingest stores
// "application/octet-stream" when none is given.
Mime string `json:"mime"`
SizeBytes int64 `json:"size_bytes"`
// StorageKey is the key under which the storage provider holds the bytes.
StorageKey string `json:"storage_key"`
OriginalName string `json:"original_name,omitempty"`
// Width and Height are presumably pixel dimensions — TODO confirm with callers.
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
// DurationMs is presumably the playback length in milliseconds — TODO confirm.
DurationMs int64 `json:"duration_ms,omitempty"`
// Metadata is the free-form JSON object stored with the asset.
Metadata map[string]any `json:"metadata,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// IngestInput carries the data needed to persist a new media asset.
type IngestInput struct {
// BotID is required and must parse as a UUID.
BotID string
// MediaType is stored with the record and becomes a segment of the storage key.
MediaType MediaType
// Mime selects the stored file extension; when blank the record stores
// "application/octet-stream".
Mime string
// OriginalName is stored only when it is not blank.
OriginalName string
// Width, Height and DurationMs are stored as NULL when zero.
Width int
Height int
DurationMs int64
// Metadata is stored as a JSON object; nil is stored as an empty object.
Metadata map[string]any
// Reader provides the raw bytes; caller is responsible for closing.
Reader io.Reader
// MaxBytes optionally overrides the size limit for this payload.
// Values <= 0 make Service.Ingest fall back to MaxAssetBytes.
MaxBytes int64
}
// MessageAssetLink represents the relationship between a message and an asset.
// NOTE(review): the fields mirror the role/ordinal arguments of
// Service.LinkToMessage, but nothing in the code visible here uses this type —
// confirm the intended consumer.
type MessageAssetLink struct {
// AssetID identifies the linked asset.
AssetID string `json:"asset_id"`
// Role names the asset's role on the message; Service.LinkToMessage
// defaults a blank role to "attachment".
Role string `json:"role"`
// Ordinal is presumably the asset's position among the message's assets — TODO confirm.
Ordinal int `json:"ordinal"`
}
// StorageProvider abstracts object storage operations.
// Service uses it to store, open and address asset bytes by storage key;
// Service.Ingest builds keys of the form
// "<bot_id>/<media_type>/<hash prefix>/<hash><ext>".
type StorageProvider interface {
// Put writes data to storage under the given key.
// The bytes are read from reader until EOF or an error.
Put(ctx context.Context, key string, reader io.Reader) error
// Open returns a reader for the given storage key.
// The caller is responsible for closing the returned reader.
Open(ctx context.Context, key string) (io.ReadCloser, error)
// Delete removes the object at key.
Delete(ctx context.Context, key string) error
// AccessPath returns a consumer-accessible reference for a storage key.
// The format depends on the backend (e.g. container path, signed URL).
AccessPath(key string) string
}