mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
501 lines
13 KiB
Go
501 lines
13 KiB
Go
package memory
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/qdrant/go-client/qdrant"
|
|
)
|
|
|
|
type QdrantStore struct {
|
|
client *qdrant.Client
|
|
collection string
|
|
dimension int
|
|
baseURL string
|
|
apiKey string
|
|
timeout time.Duration
|
|
vectorNames map[string]int
|
|
usesNamedVectors bool
|
|
}
|
|
|
|
// qdrantPoint is the package-internal representation of one stored point.
type qdrantPoint struct {
	// ID is the point identifier; Upsert, Get and Delete pass it to Qdrant as a UUID.
	ID string `json:"id"`
	// Vector is the dense embedding written on Upsert. Points returned by
	// Search, Get and List leave it unset.
	Vector []float32 `json:"vector"`
	// VectorName selects the named vector slot for Vector; it is only honoured
	// when the store uses named vectors.
	VectorName string `json:"vector_name,omitempty"`
	// Payload is the arbitrary metadata stored alongside the vector.
	Payload map[string]interface{} `json:"payload,omitempty"`
}
|
|
|
|
func NewQdrantStore(baseURL, apiKey, collection string, dimension int, timeout time.Duration) (*QdrantStore, error) {
|
|
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if collection == "" {
|
|
collection = "memory"
|
|
}
|
|
if dimension <= 0 {
|
|
dimension = 1536
|
|
}
|
|
|
|
cfg := &qdrant.Config{
|
|
Host: host,
|
|
Port: port,
|
|
APIKey: apiKey,
|
|
UseTLS: useTLS,
|
|
}
|
|
client, err := qdrant.NewClient(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
store := &QdrantStore{
|
|
client: client,
|
|
collection: collection,
|
|
dimension: dimension,
|
|
baseURL: baseURL,
|
|
apiKey: apiKey,
|
|
timeout: timeoutOrDefault(timeout),
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
|
|
defer cancel()
|
|
if err := store.ensureCollection(ctx, nil); err != nil {
|
|
return nil, err
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
func (s *QdrantStore) NewSibling(collection string, dimension int) (*QdrantStore, error) {
|
|
return NewQdrantStore(s.baseURL, s.apiKey, collection, dimension, s.timeout)
|
|
}
|
|
|
|
func NewQdrantStoreWithVectors(baseURL, apiKey, collection string, vectors map[string]int, timeout time.Duration) (*QdrantStore, error) {
|
|
host, port, useTLS, err := parseQdrantEndpoint(baseURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if collection == "" {
|
|
collection = "memory"
|
|
}
|
|
if len(vectors) == 0 {
|
|
return nil, fmt.Errorf("vectors map is required")
|
|
}
|
|
|
|
cfg := &qdrant.Config{
|
|
Host: host,
|
|
Port: port,
|
|
APIKey: apiKey,
|
|
UseTLS: useTLS,
|
|
}
|
|
client, err := qdrant.NewClient(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
store := &QdrantStore{
|
|
client: client,
|
|
collection: collection,
|
|
baseURL: baseURL,
|
|
apiKey: apiKey,
|
|
timeout: timeoutOrDefault(timeout),
|
|
vectorNames: vectors,
|
|
usesNamedVectors: true,
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeoutOrDefault(timeout))
|
|
defer cancel()
|
|
if err := store.ensureCollection(ctx, vectors); err != nil {
|
|
return nil, err
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error {
|
|
if len(points) == 0 {
|
|
return nil
|
|
}
|
|
qPoints := make([]*qdrant.PointStruct, 0, len(points))
|
|
for _, point := range points {
|
|
payload, err := qdrant.TryValueMap(point.Payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var vectors *qdrant.Vectors
|
|
if point.VectorName != "" && s.usesNamedVectors {
|
|
vectors = qdrant.NewVectorsMap(map[string]*qdrant.Vector{
|
|
point.VectorName: qdrant.NewVectorDense(point.Vector),
|
|
})
|
|
} else {
|
|
vectors = qdrant.NewVectorsDense(point.Vector)
|
|
}
|
|
qPoints = append(qPoints, &qdrant.PointStruct{
|
|
Id: qdrant.NewIDUUID(point.ID),
|
|
Vectors: vectors,
|
|
Payload: payload,
|
|
})
|
|
}
|
|
_, err := s.client.Upsert(ctx, &qdrant.UpsertPoints{
|
|
CollectionName: s.collection,
|
|
Wait: qdrant.PtrOf(true),
|
|
Points: qPoints,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, vectorName string) ([]qdrantPoint, []float64, error) {
|
|
if limit <= 0 {
|
|
limit = 10
|
|
}
|
|
filter := buildQdrantFilter(filters)
|
|
var using *string
|
|
if vectorName != "" && s.usesNamedVectors {
|
|
using = qdrant.PtrOf(vectorName)
|
|
}
|
|
results, err := s.client.Query(ctx, &qdrant.QueryPoints{
|
|
CollectionName: s.collection,
|
|
Query: qdrant.NewQueryDense(vector),
|
|
Using: using,
|
|
Limit: qdrant.PtrOf(uint64(limit)),
|
|
Filter: filter,
|
|
WithPayload: qdrant.NewWithPayload(true),
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
points := make([]qdrantPoint, 0, len(results))
|
|
scores := make([]float64, 0, len(results))
|
|
for _, scored := range results {
|
|
points = append(points, qdrantPoint{
|
|
ID: pointIDToString(scored.GetId()),
|
|
Payload: valueMapToInterface(scored.GetPayload()),
|
|
})
|
|
scores = append(scores, float64(scored.GetScore()))
|
|
}
|
|
return points, scores, nil
|
|
}
|
|
|
|
func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]interface{}, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) {
|
|
pointsBySource := make(map[string][]qdrantPoint, len(sources))
|
|
scoresBySource := make(map[string][]float64, len(sources))
|
|
if len(sources) == 0 {
|
|
return pointsBySource, scoresBySource, nil
|
|
}
|
|
for _, source := range sources {
|
|
merged := cloneFilters(filters)
|
|
if source != "" {
|
|
merged["source"] = source
|
|
}
|
|
points, scores, err := s.Search(ctx, vector, limit, merged, vectorName)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
pointsBySource[source] = points
|
|
scoresBySource[source] = scores
|
|
}
|
|
return pointsBySource, scoresBySource, nil
|
|
}
|
|
|
|
func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) {
|
|
result, err := s.client.Get(ctx, &qdrant.GetPoints{
|
|
CollectionName: s.collection,
|
|
Ids: []*qdrant.PointId{qdrant.NewIDUUID(id)},
|
|
WithPayload: qdrant.NewWithPayload(true),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(result) == 0 {
|
|
return nil, nil
|
|
}
|
|
point := result[0]
|
|
return &qdrantPoint{
|
|
ID: pointIDToString(point.GetId()),
|
|
Payload: valueMapToInterface(point.GetPayload()),
|
|
}, nil
|
|
}
|
|
|
|
func (s *QdrantStore) Delete(ctx context.Context, id string) error {
|
|
_, err := s.client.Delete(ctx, &qdrant.DeletePoints{
|
|
CollectionName: s.collection,
|
|
Wait: qdrant.PtrOf(true),
|
|
Points: qdrant.NewPointsSelectorIDs([]*qdrant.PointId{qdrant.NewIDUUID(id)}),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]interface{}) ([]qdrantPoint, error) {
|
|
if limit <= 0 {
|
|
limit = 100
|
|
}
|
|
filter := buildQdrantFilter(filters)
|
|
points, err := s.client.Scroll(ctx, &qdrant.ScrollPoints{
|
|
CollectionName: s.collection,
|
|
Limit: qdrant.PtrOf(uint32(limit)),
|
|
Filter: filter,
|
|
WithPayload: qdrant.NewWithPayload(true),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]qdrantPoint, 0, len(points))
|
|
for _, point := range points {
|
|
result = append(result, qdrantPoint{
|
|
ID: pointIDToString(point.GetId()),
|
|
Payload: valueMapToInterface(point.GetPayload()),
|
|
})
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]interface{}) error {
|
|
filter := buildQdrantFilter(filters)
|
|
if filter == nil {
|
|
return fmt.Errorf("delete all requires filters")
|
|
}
|
|
_, err := s.client.Delete(ctx, &qdrant.DeletePoints{
|
|
CollectionName: s.collection,
|
|
Wait: qdrant.PtrOf(true),
|
|
Points: qdrant.NewPointsSelectorFilter(filter),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *QdrantStore) ensureCollection(ctx context.Context, vectors map[string]int) error {
|
|
exists, err := s.client.CollectionExists(ctx, s.collection)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
return s.refreshCollectionSchema(ctx, vectors)
|
|
}
|
|
if len(vectors) > 0 {
|
|
params := make(map[string]*qdrant.VectorParams, len(vectors))
|
|
for name, dim := range vectors {
|
|
params[name] = &qdrant.VectorParams{
|
|
Size: uint64(dim),
|
|
Distance: qdrant.Distance_Cosine,
|
|
}
|
|
}
|
|
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
|
|
CollectionName: s.collection,
|
|
VectorsConfig: qdrant.NewVectorsConfigMap(params),
|
|
})
|
|
}
|
|
return s.client.CreateCollection(ctx, &qdrant.CreateCollection{
|
|
CollectionName: s.collection,
|
|
VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
|
|
Size: uint64(s.dimension),
|
|
Distance: qdrant.Distance_Cosine,
|
|
}),
|
|
})
|
|
}
|
|
|
|
func (s *QdrantStore) refreshCollectionSchema(ctx context.Context, vectors map[string]int) error {
|
|
info, err := s.client.GetCollectionInfo(ctx, s.collection)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
config := info.GetConfig()
|
|
if config == nil || config.GetParams() == nil || config.GetParams().GetVectorsConfig() == nil {
|
|
return nil
|
|
}
|
|
vectorsConfig := config.GetParams().GetVectorsConfig()
|
|
if vectorsConfig.GetParamsMap() != nil {
|
|
s.usesNamedVectors = true
|
|
s.vectorNames = map[string]int{}
|
|
for name, vec := range vectorsConfig.GetParamsMap().GetMap() {
|
|
if vec != nil {
|
|
s.vectorNames[name] = int(vec.GetSize())
|
|
}
|
|
}
|
|
if len(vectors) == 0 {
|
|
return nil
|
|
}
|
|
for name, dim := range vectors {
|
|
if existing, ok := s.vectorNames[name]; ok && existing == dim {
|
|
continue
|
|
}
|
|
return fmt.Errorf("collection missing vector %s (dim %d); migration required", name, dim)
|
|
}
|
|
return nil
|
|
}
|
|
s.usesNamedVectors = false
|
|
s.vectorNames = nil
|
|
return nil
|
|
}
|
|
|
|
// parseQdrantEndpoint splits a Qdrant endpoint into host, port and a TLS flag.
//
// Surrounding whitespace is ignored. An empty endpoint yields
// 127.0.0.1:6334 without TLS. An endpoint without a scheme is treated as
// "http://<endpoint>". A missing host falls back to 127.0.0.1 and a missing
// port to 6334. TLS is enabled only for the "https" scheme.
//
// An error is returned when the endpoint is not a valid URL or its port is
// above 65535.
func parseQdrantEndpoint(endpoint string) (string, int, bool, error) {
	const (
		defaultHost = "127.0.0.1"
		defaultPort = 6334
		maxPort     = 65535
	)

	// Values taken from config files or environment variables often carry
	// stray whitespace, which url.Parse would otherwise reject.
	endpoint = strings.TrimSpace(endpoint)
	if endpoint == "" {
		return defaultHost, defaultPort, false, nil
	}
	if !strings.Contains(endpoint, "://") {
		// Without a scheme url.Parse would not recognise "host:port" as an authority.
		endpoint = "http://" + endpoint
	}

	parsed, err := url.Parse(endpoint)
	if err != nil {
		return "", 0, false, fmt.Errorf("parsing qdrant endpoint: %w", err)
	}

	host := parsed.Hostname()
	if host == "" {
		host = defaultHost
	}

	port := defaultPort
	if rawPort := parsed.Port(); rawPort != "" {
		port, err = strconv.Atoi(rawPort)
		if err != nil {
			return "", 0, false, fmt.Errorf("parsing qdrant endpoint port %q: %w", rawPort, err)
		}
		if port > maxPort {
			return "", 0, false, fmt.Errorf("qdrant endpoint port %d is out of range", port)
		}
	}

	return host, port, parsed.Scheme == "https", nil
}
|
|
|
|
// timeoutOrDefault returns timeout when it is positive and ten seconds
// otherwise.
func timeoutOrDefault(timeout time.Duration) time.Duration {
	const fallback = 10 * time.Second
	if timeout > 0 {
		return timeout
	}
	return fallback
}
|
|
|
|
func buildQdrantFilter(filters map[string]interface{}) *qdrant.Filter {
|
|
if len(filters) == 0 {
|
|
return nil
|
|
}
|
|
conditions := make([]*qdrant.Condition, 0, len(filters))
|
|
for key, value := range filters {
|
|
if condition := buildQdrantCondition(key, value); condition != nil {
|
|
conditions = append(conditions, condition)
|
|
}
|
|
}
|
|
if len(conditions) == 0 {
|
|
return nil
|
|
}
|
|
return &qdrant.Filter{
|
|
Must: conditions,
|
|
}
|
|
}
|
|
|
|
// cloneFilters returns a shallow copy of filters. The result is never nil,
// so callers can add keys to it even when filters is nil or empty.
func cloneFilters(filters map[string]interface{}) map[string]interface{} {
	dup := make(map[string]interface{}, len(filters))
	for k, v := range filters {
		dup[k] = v
	}
	return dup
}
|
|
|
|
func buildQdrantCondition(key string, value interface{}) *qdrant.Condition {
|
|
switch typed := value.(type) {
|
|
case string:
|
|
return qdrant.NewMatch(key, typed)
|
|
case bool:
|
|
return qdrant.NewMatchBool(key, typed)
|
|
case int:
|
|
return qdrant.NewMatchInt(key, int64(typed))
|
|
case int64:
|
|
return qdrant.NewMatchInt(key, typed)
|
|
case float32:
|
|
v := float64(typed)
|
|
return qdrant.NewRange(key, &qdrant.Range{Gte: &v, Lte: &v})
|
|
case float64:
|
|
return qdrant.NewRange(key, &qdrant.Range{Gte: &typed, Lte: &typed})
|
|
case map[string]interface{}:
|
|
rangeValue := &qdrant.Range{}
|
|
for _, op := range []string{"gte", "gt", "lte", "lt"} {
|
|
if raw, ok := typed[op]; ok {
|
|
val, ok := toFloat(raw)
|
|
if !ok {
|
|
continue
|
|
}
|
|
switch op {
|
|
case "gte":
|
|
rangeValue.Gte = &val
|
|
case "gt":
|
|
rangeValue.Gt = &val
|
|
case "lte":
|
|
rangeValue.Lte = &val
|
|
case "lt":
|
|
rangeValue.Lt = &val
|
|
}
|
|
}
|
|
}
|
|
if rangeValue.Gte != nil || rangeValue.Gt != nil || rangeValue.Lte != nil || rangeValue.Lt != nil {
|
|
return qdrant.NewRange(key, rangeValue)
|
|
}
|
|
}
|
|
return qdrant.NewMatch(key, fmt.Sprint(value))
|
|
}
|
|
|
|
// toFloat converts a numeric value to float64.
//
// All built-in float and integer kinds are accepted, so a range bound that
// arrives as e.g. int32 or uint64 is no longer silently dropped. The second
// result is false for any other type, including nil. Very large 64-bit
// integers may lose precision in the conversion.
func toFloat(value interface{}) (float64, bool) {
	switch typed := value.(type) {
	case float64:
		return typed, true
	case float32:
		return float64(typed), true
	case int:
		return float64(typed), true
	case int8:
		return float64(typed), true
	case int16:
		return float64(typed), true
	case int32:
		return float64(typed), true
	case int64:
		return float64(typed), true
	case uint:
		return float64(typed), true
	case uint8:
		return float64(typed), true
	case uint16:
		return float64(typed), true
	case uint32:
		return float64(typed), true
	case uint64:
		return float64(typed), true
	default:
		return 0, false
	}
}
|
|
|
|
func pointIDToString(id *qdrant.PointId) string {
|
|
if id == nil {
|
|
return ""
|
|
}
|
|
if uuid := id.GetUuid(); uuid != "" {
|
|
return uuid
|
|
}
|
|
if num := id.GetNum(); num != 0 {
|
|
return fmt.Sprintf("%d", num)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func valueMapToInterface(values map[string]*qdrant.Value) map[string]interface{} {
|
|
result := make(map[string]interface{}, len(values))
|
|
for key, value := range values {
|
|
result[key] = valueToInterface(value)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func valueToInterface(value *qdrant.Value) interface{} {
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
switch kind := value.GetKind().(type) {
|
|
case *qdrant.Value_NullValue:
|
|
return nil
|
|
case *qdrant.Value_BoolValue:
|
|
return kind.BoolValue
|
|
case *qdrant.Value_IntegerValue:
|
|
return kind.IntegerValue
|
|
case *qdrant.Value_DoubleValue:
|
|
return kind.DoubleValue
|
|
case *qdrant.Value_StringValue:
|
|
return kind.StringValue
|
|
case *qdrant.Value_StructValue:
|
|
return valueMapToInterface(kind.StructValue.GetFields())
|
|
case *qdrant.Value_ListValue:
|
|
items := make([]interface{}, 0, len(kind.ListValue.GetValues()))
|
|
for _, item := range kind.ListValue.GetValues() {
|
|
items = append(items, valueToInterface(item))
|
|
}
|
|
return items
|
|
default:
|
|
return nil
|
|
}
|
|
}
|